Merge tag '1.8.1' into merge-tag-1.8.1

# Conflicts:
#	.gitignore
#	README.md
#	api/.env.example
#	api/Dockerfile
#	api/commands.py
#	api/configs/app_config.py
#	api/controllers/console/__init__.py
#	api/controllers/console/apikey.py
#	api/controllers/console/app/statistic.py
#	api/controllers/service_api/app/app.py
#	api/controllers/service_api/app/audio.py
#	api/controllers/service_api/app/completion.py
#	api/controllers/service_api/app/conversation.py
#	api/controllers/service_api/app/file.py
#	api/controllers/service_api/app/message.py
#	api/controllers/service_api/app/workflow.py
#	api/controllers/service_api/wraps.py
#	api/controllers/web/completion.py
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/agent_chat/app_generator.py
#	api/core/app/apps/workflow/app_generator.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_manage.py
#	api/core/helper/code_executor/code_executor.py
#	api/core/tools/builtin_tool/providers/code/tools/simple_code.py
#	api/core/workflow/nodes/code/code_node.py
#	api/docker/entrypoint.sh
#	api/events/event_handlers/__init__.py
#	api/extensions/ext_celery.py
#	api/extensions/ext_commands.py
#	api/models/model.py
#	api/models/workflow.py
#	api/poetry.lock
#	api/pyproject.toml
#	api/services/app_service.py
#	api/services/feature_service.py
#	api/services/workspace_service.py
#	web/.env.example
#	web/Dockerfile
#	web/app/(commonLayout)/apps/Apps.tsx
#	web/app/components/apps/app-card.tsx
#	web/app/components/base/chat/embedded-chatbot/index.tsx
#	web/app/components/base/mermaid/index.tsx
#	web/app/components/develop/index.tsx
#	web/app/components/develop/secret-key/secret-key-modal.tsx
#	web/app/components/develop/secret-key/style.module.css
#	web/app/components/develop/template/template.zh.mdx
#	web/app/components/explore/app-list/index.tsx
#	web/app/components/explore/category.tsx
#	web/app/components/explore/sidebar/index.tsx
#	web/app/components/header/account-dropdown/index.tsx
#	web/app/components/header/index.tsx
#	web/app/components/share/utils.ts
#	web/app/layout.tsx
#	web/app/signin/components/mail-and-password-auth.tsx
#	web/app/signin/normal-form.tsx
#	web/app/signin/page.module.css
#	web/context/app-context.tsx
#	web/i18n/i18next-config.ts
#	web/i18n/ja-JP/login.ts
#	web/i18n/ko-KR/login.ts
#
    if dify_config.WORKFLOW_LOG_CLEANUP_ENABLED:
        # 2:00 AM every day
        imports.append("schedule.clean_workflow_runlogs_precise")
        beat_schedule["clean_workflow_runlogs_precise"] = {
            "task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
            "schedule": crontab(minute="0", hour="2"),
        }	web/package.json
#	web/pnpm-lock.yaml
#	web/types/feature.ts
This commit is contained in:
npc0-hue
2025-09-25 15:55:13 +08:00
3566 changed files with 237942 additions and 62178 deletions
+23 -24
View File
@@ -1,18 +1,20 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def add_document_to_index_task(dataset_document_id: str):
@@ -22,24 +24,29 @@ def add_document_to_index_task(dataset_document_id: str):
Usage: add_document_to_index_task.delay(dataset_document_id)
"""
logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green"))
logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green"))
start_at = time.perf_counter()
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
if not dataset_document:
logging.info(click.style("Document not found: {}".format(dataset_document_id), fg="red"))
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
db.session.close()
return
if dataset_document.indexing_status != "completed":
db.session.close()
return
indexing_cache_key = "document_{}_indexing".format(dataset_document.id)
indexing_cache_key = f"document_{dataset_document.id}_indexing"
try:
dataset = dataset_document.dataset
if not dataset:
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == False,
DocumentSegment.status == "completed",
@@ -77,43 +84,35 @@ def add_document_to_index_task(dataset_document_id: str):
document.children = child_documents
documents.append(document)
dataset = dataset_document.dataset
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)
# delete auto disable log
db.session.query(DatasetAutoDisableLog).filter(
DatasetAutoDisableLog.document_id == dataset_document.id
).delete()
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
# update segment to enable
db.session.query(DocumentSegment).filter(DocumentSegment.document_id == dataset_document.id).update(
db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
{
DocumentSegment.enabled: True,
DocumentSegment.disabled_at: None,
DocumentSegment.disabled_by: None,
DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
DocumentSegment.updated_at: naive_utc_now(),
}
)
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style(
"Document added to index: {} latency: {}".format(dataset_document.id, end_at - start_at), fg="green"
)
logger.info(
click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
)
except Exception as e:
logging.exception("add document to index failed")
logger.exception("add document to index failed")
dataset_document.enabled = False
dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
dataset_document.status = "error"
dataset_document.disabled_at = naive_utc_now()
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
@@ -2,7 +2,7 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
@@ -10,6 +10,8 @@ from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def add_annotation_to_index_task(
@@ -25,7 +27,7 @@ def add_annotation_to_index_task(
Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
"""
logging.info(click.style("Start build index for annotation: {}".format(annotation_id), fg="green"))
logger.info(click.style(f"Start build index for annotation: {annotation_id}", fg="green"))
start_at = time.perf_counter()
try:
@@ -48,13 +50,13 @@ def add_annotation_to_index_task(
vector.create([document], duplicate_check=True)
end_at = time.perf_counter()
logging.info(
logger.info(
click.style(
"Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at),
f"Build index successful for annotation: {annotation_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logging.exception("Build index for annotation failed")
logger.exception("Build index for annotation failed")
finally:
db.session.close()
@@ -2,7 +2,7 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
@@ -13,6 +13,8 @@ from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, user_id: str):
@@ -25,11 +27,11 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
:param user_id: user_id
"""
logging.info(click.style("Start batch import annotation: {}".format(job_id), fg="green"))
logger.info(click.style(f"Start batch import annotation: {job_id}", fg="green"))
start_at = time.perf_counter()
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
# get app info
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if app:
try:
@@ -48,7 +50,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
documents.append(document)
# if annotation reply is enabled , batch add annotations' index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
@@ -74,7 +76,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logging.info(
logger.info(
click.style(
"Build index successful for batch import annotation: {} latency: {}".format(
job_id, end_at - start_at
@@ -85,8 +87,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
except Exception as e:
db.session.rollback()
redis_client.setex(indexing_cache_key, 600, "error")
indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
redis_client.setex(indexing_error_msg_key, 600, str(e))
logging.exception("Build index for batch import annotations failed")
logger.exception("Build index for batch import annotations failed")
finally:
db.session.close()
@@ -2,20 +2,22 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, collection_binding_id: str):
"""
Async delete annotation index task
"""
logging.info(click.style("Start delete app annotation index: {}".format(app_id), fg="green"))
logger.info(click.style(f"Start delete app annotation index: {app_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
@@ -33,12 +35,10 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete_by_metadata_field("annotation_id", annotation_id)
except Exception:
logging.exception("Delete annotation index failed when annotation deleted.")
logger.exception("Delete annotation index failed when annotation deleted.")
end_at = time.perf_counter()
logging.info(
click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green")
)
except Exception as e:
logging.exception("Annotation deleted index failed")
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Annotation deleted index failed")
finally:
db.session.close()
@@ -2,7 +2,8 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from sqlalchemy import exists, select
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@@ -10,33 +11,33 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
"""
Async enable annotation reply task
"""
logging.info(click.style("Start delete app annotations index: {}".format(app_id), fg="green"))
logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count()
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
if not app:
logging.info(click.style("App not found: {}".format(app_id), fg="red"))
logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close()
return
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if not app_annotation_setting:
logging.info(click.style("App annotation setting not found: {}".format(app_id), fg="red"))
logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
db.session.close()
return
disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id))
disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id))
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
try:
dataset = Dataset(
@@ -47,11 +48,11 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
)
try:
if annotations_count > 0:
if annotations_exists:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete()
except Exception:
logging.exception("Delete annotation index failed when annotation deleted.")
logger.exception("Delete annotation index failed when annotation deleted.")
redis_client.setex(disable_app_annotation_job_key, 600, "completed")
# delete annotation setting
@@ -59,13 +60,11 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green")
)
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logging.exception("Annotation batch deleted index failed")
logger.exception("Annotation batch deleted index failed")
redis_client.setex(disable_app_annotation_job_key, 600, "error")
disable_app_annotation_error_key = "disable_app_annotation_error_{}".format(str(job_id))
disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
finally:
redis_client.delete(disable_app_annotation_key)
@@ -1,18 +1,20 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def enable_annotation_reply_task(
@@ -27,28 +29,26 @@ def enable_annotation_reply_task(
"""
Async enable annotation reply task
"""
logging.info(click.style("Start add app annotation to index: {}".format(app_id), fg="green"))
logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if not app:
logging.info(click.style("App not found: {}".format(app_id), fg="red"))
logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close()
return
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id))
enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id))
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all()
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
try:
documents = []
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
)
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
@@ -70,11 +70,11 @@ def enable_annotation_reply_task(
try:
old_vector.delete()
except Exception as e:
logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red"))
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
annotation_setting.score_threshold = score_threshold
annotation_setting.collection_binding_id = dataset_collection_binding.id
annotation_setting.updated_user_id = user_id
annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
annotation_setting.updated_at = naive_utc_now()
db.session.add(annotation_setting)
else:
new_app_annotation_setting = AppAnnotationSetting(
@@ -106,18 +106,16 @@ def enable_annotation_reply_task(
try:
vector.delete_by_metadata_field("app_id", app_id)
except Exception as e:
logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red"))
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
vector.create(documents)
db.session.commit()
redis_client.setex(enable_app_annotation_job_key, 600, "completed")
end_at = time.perf_counter()
logging.info(
click.style("App annotations added to index: {} latency: {}".format(app_id, end_at - start_at), fg="green")
)
logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logging.exception("Annotation batch created index failed")
logger.exception("Annotation batch created index failed")
redis_client.setex(enable_app_annotation_job_key, 600, "error")
enable_app_annotation_error_key = "enable_app_annotation_error_{}".format(str(job_id))
enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
redis_client.setex(enable_app_annotation_error_key, 600, str(e))
db.session.rollback()
finally:
@@ -2,7 +2,7 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
@@ -10,6 +10,8 @@ from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def update_annotation_to_index_task(
@@ -25,7 +27,7 @@ def update_annotation_to_index_task(
Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
"""
logging.info(click.style("Start update index for annotation: {}".format(annotation_id), fg="green"))
logger.info(click.style(f"Start update index for annotation: {annotation_id}", fg="green"))
start_at = time.perf_counter()
try:
@@ -49,13 +51,13 @@ def update_annotation_to_index_task(
vector.delete_by_metadata_field("annotation_id", annotation_id)
vector.add_texts([document])
end_at = time.perf_counter()
logging.info(
logger.info(
click.style(
"Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at),
f"Build index successful for annotation: {annotation_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logging.exception("Build index for annotation failed")
logger.exception("Build index for annotation failed")
finally:
db.session.close()
+15 -12
View File
@@ -2,7 +2,7 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -11,6 +11,8 @@ from extensions.ext_storage import storage
from models.dataset import Dataset, DocumentSegment
from models.model import UploadFile
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
@@ -23,16 +25,16 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
Usage: batch_clean_document_task.delay(document_ids, dataset_id)
"""
logging.info(click.style("Start batch clean documents when documents deleted", fg="green"))
logger.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -42,37 +44,38 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logging.exception(
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: {}".format(upload_file_id)
image_upload_file_is: %s",
upload_file_id,
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.commit()
if file_ids:
files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all()
files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id))
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
db.session.delete(file)
db.session.commit()
end_at = time.perf_counter()
logging.info(
logger.info(
click.style(
"Cleaned documents when documents deleted latency: {}".format(end_at - start_at),
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logging.exception("Cleaned documents when documents deleted failed")
logger.exception("Cleaned documents when documents deleted failed")
finally:
db.session.close()
+45 -19
View File
@@ -1,26 +1,33 @@
import datetime
import logging
import tempfile
import time
import uuid
from pathlib import Path
import click
from celery import shared_task # type: ignore
from sqlalchemy import func, select
import pandas as pd
from celery import shared_task
from sqlalchemy import func
from sqlalchemy.orm import Session
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from services.vector_service import VectorService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def batch_create_segment_to_index_task(
job_id: str,
content: list,
upload_file_id: str,
dataset_id: str,
document_id: str,
tenant_id: str,
@@ -29,18 +36,18 @@ def batch_create_segment_to_index_task(
"""
Async batch create segment to index
:param job_id:
:param content:
:param upload_file_id:
:param dataset_id:
:param document_id:
:param tenant_id:
:param user_id:
Usage: batch_create_segment_to_index_task.delay(job_id, content, dataset_id, document_id, tenant_id, user_id)
Usage: batch_create_segment_to_index_task.delay(job_id, upload_file_id, dataset_id, document_id, tenant_id, user_id)
"""
logging.info(click.style("Start batch create segment jobId: {}".format(job_id), fg="green"))
logger.info(click.style(f"Start batch create segment jobId: {job_id}", fg="green"))
start_at = time.perf_counter()
indexing_cache_key = "segment_batch_import_{}".format(job_id)
indexing_cache_key = f"segment_batch_import_{job_id}"
try:
with Session(db.engine) as session:
@@ -58,6 +65,29 @@ def batch_create_segment_to_index_task(
or dataset_document.indexing_status != "completed"
):
raise ValueError("Document is not available.")
upload_file = session.get(UploadFile, upload_file_id)
if not upload_file:
raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
# Skip the first row
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
@@ -68,11 +98,6 @@ def batch_create_segment_to_index_task(
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
word_count_change = 0
segments_to_insert: list[str] = []
max_position_stmt = select(func.max(DocumentSegment.position)).where(
DocumentSegment.document_id == dataset_document.id
)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
@@ -86,7 +111,7 @@ def batch_create_segment_to_index_task(
segment_hash = helper.generate_text_hash(content) # type: ignore
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == dataset_document.id)
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
@@ -100,9 +125,9 @@ def batch_create_segment_to_index_task(
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
indexing_at=naive_utc_now(),
status="completed",
completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
@@ -111,6 +136,7 @@ def batch_create_segment_to_index_task(
db.session.add(segment_document)
document_segments.append(segment_document)
# update document word count
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
# add index to db
@@ -118,14 +144,14 @@ def batch_create_segment_to_index_task(
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logging.info(
logger.info(
click.style(
"Segment batch created job: {} latency: {}".format(job_id, end_at - start_at),
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logging.exception("Segments batch created index failed")
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
finally:
db.session.close()
+54 -26
View File
@@ -2,10 +2,10 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import (
@@ -20,6 +20,8 @@ from models.dataset import (
)
from models.model import UploadFile
logger = logging.getLogger(__name__)
# Add import statement for ValueError
@shared_task(queue="dataset")
@@ -42,7 +44,7 @@ def clean_dataset_task(
Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
"""
logging.info(click.style("Start clean dataset when dataset deleted: {}".format(dataset_id), fg="green"))
logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
@@ -53,18 +55,37 @@ def clean_dataset_task(
index_struct=index_struct,
collection_binding_id=collection_binding_id,
)
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
if documents is None or len(documents) == 0:
logging.info(click.style("No documents found for dataset: {}".format(dataset_id), fg="green"))
else:
logging.info(click.style("Cleaning documents for dataset: {}".format(dataset_id), fg="green"))
# Specify the index type before initializing the index processor
if doc_form is None:
raise ValueError("Index type must be specified.")
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexType
doc_form = IndexType.PARAGRAPH_INDEX
logger.info(
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
)
# Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
# This ensures Document/Segment deletion can continue even if vector database cleanup fails
try:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
except Exception:
logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
# Continue with document and segment deletion even if vector cleanup fails
logger.info(
click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
)
if documents is None or len(documents) == 0:
logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
else:
logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
for document in documents:
db.session.delete(document)
@@ -72,25 +93,26 @@ def clean_dataset_task(
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logging.exception(
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: {}".format(upload_file_id)
image_upload_file_is: %s",
upload_file_id,
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()
db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
# delete dataset metadata
db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == dataset_id).delete()
db.session.query(DatasetMetadataBinding).filter(DatasetMetadataBinding.dataset_id == dataset_id).delete()
db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete files
if documents:
for document in documents:
@@ -102,7 +124,7 @@ def clean_dataset_task(
file_id = data_source_info["upload_file_id"]
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
)
if not file:
@@ -114,12 +136,18 @@ def clean_dataset_task(
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style(
"Cleaned dataset when dataset deleted: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"
)
logger.info(
click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
)
except Exception:
logging.exception("Cleaned dataset when dataset deleted failed")
# Add rollback to prevent dirty session state in case of exceptions
# This ensures the database session is properly cleaned up
try:
db.session.rollback()
logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
except Exception:
logger.exception("Failed to rollback database session")
logger.exception("Cleaned dataset when dataset deleted failed")
finally:
db.session.close()
+18 -14
View File
@@ -3,15 +3,17 @@ import time
from typing import Optional
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: Optional[str]):
@@ -24,16 +26,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
Usage: clean_document_task.delay(document_id, dataset_id)
"""
logging.info(click.style("Start clean document when document deleted: {}".format(document_id), fg="green"))
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -43,44 +45,46 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logging.exception(
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: {}".format(upload_file_id)
image_upload_file_is: %s",
upload_file_id,
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.commit()
if file_id:
file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logging.exception("Delete file failed when document deleted, file_id: {}".format(file_id))
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
db.session.delete(file)
db.session.commit()
# delete dataset metadata binding
db.session.query(DatasetMetadataBinding).filter(
db.session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
db.session.commit()
end_at = time.perf_counter()
logging.info(
logger.info(
click.style(
"Cleaned document when document deleted: {} latency: {}".format(document_id, end_at - start_at),
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logging.exception("Cleaned document when document deleted failed")
logger.exception("Cleaned document when document deleted failed")
finally:
db.session.close()
+9 -9
View File
@@ -2,12 +2,14 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def clean_notion_document_task(document_ids: list[str], dataset_id: str):
@@ -18,23 +20,21 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
Usage: clean_notion_document_task.delay(document_ids, dataset_id)
"""
logging.info(
click.style("Start clean document when import form notion document deleted: {}".format(dataset_id), fg="green")
)
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
for document_id in document_ids:
document = db.session.query(Document).filter(Document.id == document_id).first()
document = db.session.query(Document).where(Document.id == document_id).first()
db.session.delete(document)
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
@@ -43,7 +43,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
db.session.delete(segment)
db.session.commit()
end_at = time.perf_counter()
logging.info(
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
@@ -52,6 +52,6 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
)
)
except Exception:
logging.exception("Cleaned document when import form notion document deleted failed")
logger.exception("Cleaned document when import form notion document deleted failed")
finally:
db.session.close()
+27 -24
View File
@@ -1,17 +1,19 @@
import datetime
import logging
import time
from typing import Optional
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] = None):
@@ -21,27 +23,29 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
:param keywords:
Usage: create_segment_to_index_task.delay(segment_id)
"""
logging.info(click.style("Start create segment to index: {}".format(segment_id), fg="green"))
logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "waiting":
db.session.close()
return
indexing_cache_key = "segment_{}_indexing".format(segment.id)
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
# update segment status to indexing
update_params = {
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
DocumentSegment.query.filter_by(id=segment.id).update(update_params)
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: naive_utc_now(),
}
)
db.session.commit()
document = Document(
page_content=segment.content,
@@ -56,17 +60,17 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
dataset = segment.dataset
if not dataset:
logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan"))
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan"))
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan"))
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_type = dataset.doc_form
@@ -74,21 +78,20 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
index_processor.load(dataset, [document])
# update segment to completed
update_params = {
DocumentSegment.status: "completed",
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
DocumentSegment.query.filter_by(id=segment.id).update(update_params)
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "completed",
DocumentSegment.completed_at: naive_utc_now(),
}
)
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style("Segment created to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green")
)
logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logging.exception("create segment to index failed")
logger.exception("create segment to index failed")
segment.enabled = False
segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
+19 -18
View File
@@ -1,8 +1,9 @@
import logging
import time
from typing import Literal
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -11,20 +12,22 @@ from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: str):
def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
"""
Async deal dataset from index
:param dataset_id: dataset_id
:param action: action
Usage: deal_dataset_vector_index_task.delay(dataset_id, action)
"""
logging.info(click.style("Start deal dataset vector index: {}".format(dataset_id), fg="green"))
logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
@@ -35,7 +38,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
elif action == "add":
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -46,7 +49,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
@@ -56,7 +59,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
# add from vector index
segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
@@ -76,19 +79,19 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -100,7 +103,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
@@ -113,7 +116,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
try:
segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
@@ -148,12 +151,12 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
@@ -162,10 +165,8 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(
click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
)
logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logging.exception("Deal dataset vector index failed")
logger.exception("Deal dataset vector index failed")
finally:
db.session.close()
+5 -5
View File
@@ -1,6 +1,6 @@
import logging
from celery import shared_task # type: ignore
from celery import shared_task
from extensions.ext_database import db
from models.account import Account
@@ -12,15 +12,15 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_account_task(account_id):
account = db.session.query(Account).filter(Account.id == account_id).first()
account = db.session.query(Account).where(Account.id == account_id).first()
try:
BillingService.delete_account(account_id)
except Exception as e:
logger.exception(f"Failed to delete account {account_id} from billing service.")
except Exception:
logger.exception("Failed to delete account %s from billing service.", account_id)
raise
if not account:
logger.error(f"Account {account_id} not found.")
logger.error("Account %s not found.", account_id)
return
# send success email
send_deletion_success_task.delay(account.email)
+70
View File
@@ -0,0 +1,70 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_database import db
from models import ConversationVariable
from models.model import Message, MessageAnnotation, MessageFeedback
from models.tools import ToolConversationVariables, ToolFile
from models.web import PinnedConversation
logger = logging.getLogger(__name__)
@shared_task(queue="conversation")
def delete_conversation_related_data(conversation_id: str) -> None:
"""
Delete related data conversation in correct order from datatbase to respect foreign key constraints
Args:
conversation_id: conversation Id
"""
logger.info(
click.style(f"Starting to delete conversation data from db for conversation_id {conversation_id}", fg="green")
)
start_at = time.perf_counter()
try:
db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(ToolConversationVariables).where(
ToolConversationVariables.conversation_id == conversation_id
).delete(synchronize_session=False)
db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception as e:
logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
db.session.rollback()
raise e
finally:
db.session.close()
+8 -6
View File
@@ -2,12 +2,14 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
@@ -19,14 +21,14 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id)
"""
logging.info(click.style("Start delete segment from index", fg="green"))
logger.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return
dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
dataset_document = db.session.query(Document).where(Document.id == document_id).first()
if not dataset_document:
return
@@ -38,8 +40,8 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
end_at = time.perf_counter()
logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green"))
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
except Exception:
logging.exception("delete segment from index failed")
logger.exception("delete segment from index failed")
finally:
db.session.close()
+13 -13
View File
@@ -2,13 +2,15 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def disable_segment_from_index_task(segment_id: str):
@@ -18,37 +20,37 @@ def disable_segment_from_index_task(segment_id: str):
Usage: disable_segment_from_index_task.delay(segment_id)
"""
logging.info(click.style("Start disable segment from index: {}".format(segment_id), fg="green"))
logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "completed":
logging.info(click.style("Segment is not completed, disable is not allowed: {}".format(segment_id), fg="red"))
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
db.session.close()
return
indexing_cache_key = "segment_{}_indexing".format(segment.id)
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
dataset = segment.dataset
if not dataset:
logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan"))
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan"))
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan"))
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_type = dataset_document.doc_form
@@ -56,11 +58,9 @@ def disable_segment_from_index_task(segment_id: str):
index_processor.clean(dataset, [segment.index_node_id])
end_at = time.perf_counter()
logging.info(
click.style("Segment removed from index: {} latency: {}".format(segment.id, end_at - start_at), fg="green")
)
logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception:
logging.exception("remove segment from index failed")
logger.exception("remove segment from index failed")
segment.enabled = True
db.session.commit()
finally:
+12 -10
View File
@@ -2,7 +2,7 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -10,6 +10,8 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str):
@@ -23,20 +25,20 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
db.session.close()
return
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
db.session.close()
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
db.session.close()
return
# sync index processor
@@ -44,7 +46,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
@@ -61,10 +63,10 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green"))
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
except Exception:
# update segment error msg
db.session.query(DocumentSegment).filter(
db.session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
@@ -78,6 +80,6 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)
db.session.close()
+28 -22
View File
@@ -1,17 +1,19 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.source import DataSourceOauthBinding
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def document_indexing_sync_task(dataset_id: str, document_id: str):
@@ -22,13 +24,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
Usage: document_indexing_sync_task.delay(dataset_id, document_id)
"""
logging.info(click.style("Start sync document: {}".format(document_id), fg="green"))
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
@@ -44,14 +46,18 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).first()
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
@@ -68,18 +74,18 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.processing_started_at = naive_utc_now()
db.session.commit()
# delete all document segment and index
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
@@ -89,7 +95,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
db.session.delete(segment)
end_at = time.perf_counter()
logging.info(
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
@@ -98,16 +104,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
)
)
except Exception:
logging.exception("Cleaned document when document update data source or process rule failed")
logger.exception("Cleaned document when document update data source or process rule failed")
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logging.info(
click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")
)
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
finally:
db.session.close()
+14 -12
View File
@@ -1,16 +1,18 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def document_indexing_task(dataset_id: str, document_ids: list):
@@ -24,9 +26,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset is not found: {}".format(dataset_id), fg="yellow"))
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
db.session.close()
return
# check document limit
@@ -48,27 +50,27 @@ def document_indexing_task(dataset_id: str, document_ids: list):
except Exception as e:
for document_id in document_ids:
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
db.session.close()
return
for document_id in document_ids:
logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
@@ -77,10 +79,10 @@ def document_indexing_task(dataset_id: str, document_ids: list):
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
+15 -13
View File
@@ -1,15 +1,17 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def document_indexing_update_task(dataset_id: str, document_id: str):
@@ -20,30 +22,30 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
Usage: document_indexing_update_task.delay(dataset_id, document_id)
"""
logging.info(click.style("Start update document: {}".format(document_id), fg="green"))
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.processing_started_at = naive_utc_now()
db.session.commit()
# delete all document segment and index
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -54,7 +56,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
db.session.delete(segment)
db.session.commit()
end_at = time.perf_counter()
logging.info(
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
@@ -63,16 +65,16 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
)
)
except Exception:
logging.exception("Cleaned document when document update data source or process rule failed")
logger.exception("Cleaned document when document update data source or process rule failed")
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green"))
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
finally:
db.session.close()
+67 -63
View File
@@ -1,17 +1,19 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
@@ -25,80 +27,82 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset is None:
logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
current = int(getattr(vector_space, "size", 0) or 0)
limit = int(getattr(vector_space, "limit", 0) or 0)
if limit > 0 and (current + count) > limit:
raise ValueError(
"Your total number of documents plus the number of uploads have exceeded the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
except Exception as e:
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
return
finally:
db.session.close()
for document_id in document_ids:
logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
documents.append(document)
db.session.add(document)
db.session.commit()
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
+15 -15
View File
@@ -1,17 +1,19 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def enable_segment_to_index_task(segment_id: str):
@@ -21,21 +23,21 @@ def enable_segment_to_index_task(segment_id: str):
Usage: enable_segment_to_index_task.delay(segment_id)
"""
logging.info(click.style("Start enable segment to index: {}".format(segment_id), fg="green"))
logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "completed":
logging.info(click.style("Segment is not completed, enable is not allowed: {}".format(segment_id), fg="red"))
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
db.session.close()
return
indexing_cache_key = "segment_{}_indexing".format(segment.id)
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
document = Document(
@@ -51,17 +53,17 @@ def enable_segment_to_index_task(segment_id: str):
dataset = segment.dataset
if not dataset:
logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan"))
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan"))
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan"))
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
@@ -85,13 +87,11 @@ def enable_segment_to_index_task(segment_id: str):
index_processor.load(dataset, [document])
end_at = time.perf_counter()
logging.info(
click.style("Segment enabled to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green")
)
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logging.exception("enable segment to index failed")
logger.exception("enable segment to index failed")
segment.enabled = False
segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
+16 -14
View File
@@ -1,18 +1,20 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str):
@@ -25,19 +27,19 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
db.session.close()
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
db.session.close()
return
# sync index processor
@@ -45,7 +47,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
@@ -53,7 +55,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
.all()
)
if not segments:
logging.info(click.style("Segments not found: {}".format(segment_ids), fg="cyan"))
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
db.session.close()
return
@@ -91,11 +93,11 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
index_processor.load(dataset, documents)
end_at = time.perf_counter()
logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green"))
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
except Exception as e:
logging.exception("enable segments to index failed")
logger.exception("enable segments to index failed")
# update segment error msg
db.session.query(DocumentSegment).filter(
db.session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
@@ -103,13 +105,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
{
"error": str(e),
"status": "error",
"disabled_at": datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
"disabled_at": naive_utc_now(),
"enabled": False,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)
db.session.close()
@@ -12,15 +12,15 @@ from extensions.ext_database import db
from models.account import Account
from models.account_money_extend import AccountMoneyExtend
from models.api_token_money_extend import ApiTokenMessageJoinsExtend, ApiTokenMoneyExtend
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.model_extend import EndUserAccountJoinsExtend
from models.workflow import WorkflowNodeExecution
from models.workflow import WorkflowNodeExecutionModel
@shared_task(queue="extend_high", bind=True, max_retries=3)
def update_account_money_when_workflow_node_execution_created_extend(self, workflow_node_execution_dict: dict):
""" """
workflowNodeExecution = WorkflowNodeExecution(**workflow_node_execution_dict)
workflowNodeExecution = WorkflowNodeExecutionModel(**workflow_node_execution_dict)
# 非大模型则跳过
if workflowNodeExecution.node_type != NodeType.LLM.value:
return
@@ -41,7 +41,7 @@ def update_account_money_when_workflow_node_execution_created_extend(self, workf
# web应用的请求,created_by记录的是登录账号的ID,可以拿这个ID来扣钱
# API调用,created_by记录的是节点登录账号ID,真正需要扣钱的在关联表EndUserAccountJoinsExtend,需要多做一层查询
payerId = workflowNodeExecution.created_by # 付钱的ID
if workflowNodeExecution.created_by_role == CreatedByRole.END_USER.value:
if workflowNodeExecution.created_by_role == CreatorUserRole.END_USER.value:
account = db.session.query(Account).filter(Account.id == workflowNodeExecution.created_by).first()
if not account:
end_user_account_joins = (
+58 -38
View File
@@ -2,59 +2,79 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from flask import render_template
from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_deletion_success_task(to):
"""Send email to user regarding account deletion."""
if not mail.is_inited():
return
logging.info(click.style(f"Start send account deletion success email to {to}", fg="green"))
start_at = time.perf_counter()
try:
html_content = render_template(
"delete_account_success_template_en-US.html",
to=to,
email=to,
)
mail.send(to=to, subject="Your Dify.AI Account Has Been Successfully Deleted", html=html_content)
end_at = time.perf_counter()
logging.info(
click.style(
"Send account deletion success email to {}: latency: {}".format(to, end_at - start_at), fg="green"
)
)
except Exception:
logging.exception("Send account deletion success email to {} failed".format(to))
@shared_task(queue="mail")
def send_account_deletion_verification_code(to, code):
"""Send email to user regarding account deletion verification code.
def send_deletion_success_task(to: str, language: str = "en-US") -> None:
"""
Send account deletion success email with internationalization support.
Args:
to (str): Recipient email address
code (str): Verification code
to: Recipient email address
language: Language code for email localization
"""
if not mail.is_inited():
return
logging.info(click.style(f"Start send account deletion verification code email to {to}", fg="green"))
logger.info(click.style(f"Start send account deletion success email to {to}", fg="green"))
start_at = time.perf_counter()
try:
html_content = render_template("delete_account_code_email_template_en-US.html", to=to, code=code)
mail.send(to=to, subject="Dify.AI Account Deletion and Verification", html=html_content)
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.ACCOUNT_DELETION_SUCCESS,
language_code=language,
to=to,
template_context={
"to": to,
"email": to,
},
)
end_at = time.perf_counter()
logging.info(
logger.info(
click.style(f"Send account deletion success email to {to}: latency: {end_at - start_at}", fg="green")
)
except Exception:
logger.exception("Send account deletion success email to %s failed", to)
@shared_task(queue="mail")
def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US") -> None:
"""
Send account deletion verification code email with internationalization support.
Args:
to: Recipient email address
code: Verification code
language: Language code for email localization
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start send account deletion verification code email to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.ACCOUNT_DELETION_VERIFICATION,
language_code=language,
to=to,
template_context={
"to": to,
"code": code,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
"Send account deletion verification code email to {} succeeded: latency: {}".format(
to, end_at - start_at
@@ -63,4 +83,4 @@ def send_account_deletion_verification_code(to, code):
)
)
except Exception:
logging.exception("Send account deletion verification code email to {} failed".format(to))
logger.exception("Send account deletion verification code email to %s failed", to)
+80
View File
@@ -0,0 +1,80 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_change_mail_task(language: str, to: str, code: str, phase: str) -> None:
"""
Send change email notification with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
code: Email verification code
phase: Change email phase ('old_email' or 'new_email')
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start change email mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_change_email(
language_code=language,
to=to,
code=code,
phase=phase,
)
end_at = time.perf_counter()
logger.info(click.style(f"Send change email mail to {to} succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Send change email mail to %s failed", to)
@shared_task(queue="mail")
def send_change_mail_completed_notification_task(language: str, to: str) -> None:
"""
Send change email completed notification with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start change email completed notify mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.CHANGE_EMAIL_COMPLETED,
language_code=language,
to=to,
template_context={
"to": to,
"email": to,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Send change email completed mail to {to} succeeded: latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Send change email completed mail to %s failed", to)
+25 -20
View File
@@ -2,40 +2,45 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from flask import render_template
from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_email_code_login_mail_task(language: str, to: str, code: str):
def send_email_code_login_mail_task(language: str, to: str, code: str) -> None:
"""
Async Send email code login mail
:param language: Language in which the email should be sent (e.g., 'en', 'zh')
:param to: Recipient email address
:param code: Email code to be included in the email
Send email code login email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
code: Email verification code
"""
if not mail.is_inited():
return
logging.info(click.style("Start email code login mail to {}".format(to), fg="green"))
logger.info(click.style(f"Start email code login mail to {to}", fg="green"))
start_at = time.perf_counter()
# send email code login mail using different languages
try:
if language == "zh-Hans":
html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code)
mail.send(to=to, subject="邮箱验证码", html=html_content)
else:
html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code)
mail.send(to=to, subject="Email Code", html=html_content)
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.EMAIL_CODE_LOGIN,
language_code=language,
to=to,
template_context={
"to": to,
"code": code,
},
)
end_at = time.perf_counter()
logging.info(
click.style(
"Send email code login mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green"
)
logger.info(
click.style(f"Send email code login mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
)
except Exception:
logging.exception("Send email code login mail to {} failed".format(to))
logger.exception("Send email code login mail to %s failed", to)
+32
View File
@@ -0,0 +1,32 @@
import logging
import time
from collections.abc import Mapping
import click
from celery import shared_task
from flask import render_template_string
from extensions.ext_mail import mail
from libs.email_i18n import get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_inner_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]):
if not mail.is_inited():
return
logger.info(click.style(f"Start enterprise mail to {to} with subject {subject}", fg="green"))
start_at = time.perf_counter()
try:
html_content = render_template_string(body, **substitutions)
email_service = get_email_i18n_service()
email_service.send_raw_email(to=to, subject=subject, html_content=html_content)
end_at = time.perf_counter()
logger.info(click.style(f"Send enterprise mail to {to} succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Send enterprise mail to %s failed", to)
+27 -38
View File
@@ -2,60 +2,49 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from flask import render_template
from celery import shared_task
from configs import dify_config
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str):
def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str) -> None:
"""
Async Send invite member mail
:param language
:param to
:param token
:param inviter_name
:param workspace_name
Send invite member email with internationalization support.
Usage: send_invite_member_mail_task.delay(language, to, token, inviter_name, workspace_name)
Args:
language: Language code for email localization
to: Recipient email address
token: Invitation token
inviter_name: Name of the person sending the invitation
workspace_name: Name of the workspace
"""
if not mail.is_inited():
return
logging.info(
click.style("Start send invite member mail to {} in workspace {}".format(to, workspace_name), fg="green")
)
logger.info(click.style(f"Start send invite member mail to {to} in workspace {workspace_name}", fg="green"))
start_at = time.perf_counter()
# send invite member mail using different languages
try:
url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}"
if language == "zh-Hans":
html_content = render_template(
"invite_member_mail_template_zh-CN.html",
to=to,
inviter_name=inviter_name,
workspace_name=workspace_name,
url=url,
)
mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)
else:
html_content = render_template(
"invite_member_mail_template_en-US.html",
to=to,
inviter_name=inviter_name,
workspace_name=workspace_name,
url=url,
)
mail.send(to=to, subject="Join Dify Workspace Now", html=html_content)
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.INVITE_MEMBER,
language_code=language,
to=to,
template_context={
"to": to,
"inviter_name": inviter_name,
"workspace_name": workspace_name,
"url": url,
},
)
end_at = time.perf_counter()
logging.info(
click.style(
"Send invite member mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green"
)
)
logger.info(click.style(f"Send invite member mail to {to} succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:
logging.exception("Send invite member mail to {} failed".format(to))
logger.exception("Send invite member mail to %s failed", to)
+131
View File
@@ -0,0 +1,131 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str) -> None:
"""
Send owner transfer confirmation email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
code: Verification code
workspace: Workspace name
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start owner transfer confirm mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.OWNER_TRANSFER_CONFIRM,
language_code=language,
to=to,
template_context={
"to": to,
"code": code,
"WorkspaceName": workspace,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Send owner transfer confirm mail to {to} succeeded: latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("owner transfer confirm email mail to %s failed", to)
@shared_task(queue="mail")
def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str) -> None:
"""
Send old owner transfer notification email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
workspace: Workspace name
new_owner_email: New owner email address
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start old owner transfer notify mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.OWNER_TRANSFER_OLD_NOTIFY,
language_code=language,
to=to,
template_context={
"to": to,
"WorkspaceName": workspace,
"NewOwnerEmail": new_owner_email,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Send old owner transfer notify mail to {to} succeeded: latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("old owner transfer notify email mail to %s failed", to)
@shared_task(queue="mail")
def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str) -> None:
"""
Send new owner transfer notification email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
workspace: Workspace name
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start new owner transfer notify mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.OWNER_TRANSFER_NEW_NOTIFY,
language_code=language,
to=to,
template_context={
"to": to,
"WorkspaceName": workspace,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Send new owner transfer notify mail to {to} succeeded: latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("new owner transfer notify email mail to %s failed", to)
+25 -20
View File
@@ -2,40 +2,45 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from flask import render_template
from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_reset_password_mail_task(language: str, to: str, code: str):
def send_reset_password_mail_task(language: str, to: str, code: str) -> None:
"""
Async Send reset password mail
:param language: Language in which the email should be sent (e.g., 'en', 'zh')
:param to: Recipient email address
:param code: Reset password code
Send reset password email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
code: Reset password code
"""
if not mail.is_inited():
return
logging.info(click.style("Start password reset mail to {}".format(to), fg="green"))
logger.info(click.style(f"Start password reset mail to {to}", fg="green"))
start_at = time.perf_counter()
# send reset password mail using different languages
try:
if language == "zh-Hans":
html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, code=code)
mail.send(to=to, subject="设置您的 Dify 密码", html=html_content)
else:
html_content = render_template("reset_password_mail_template_en-US.html", to=to, code=code)
mail.send(to=to, subject="Set Your Dify Password", html=html_content)
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
language_code=language,
to=to,
template_context={
"to": to,
"code": code,
},
)
end_at = time.perf_counter()
logging.info(
click.style(
"Send password reset mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green"
)
logger.info(
click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
)
except Exception:
logging.exception("Send password reset mail to {} failed".format(to))
logger.exception("Send password reset mail to %s failed", to)
+7 -4
View File
@@ -1,7 +1,7 @@
import json
import logging
from celery import shared_task # type: ignore
from celery import shared_task
from flask import current_app
from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
@@ -12,6 +12,8 @@ from extensions.ext_storage import storage
from models.model import Message
from models.workflow import WorkflowRun
logger = logging.getLogger(__name__)
@shared_task(queue="ops_trace")
def process_trace_tasks(file_info):
@@ -43,10 +45,11 @@ def process_trace_tasks(file_info):
if trace_type:
trace_info = trace_type(**trace_info)
trace_instance.trace(trace_info)
logging.info(f"Processing trace tasks success, app_id: {app_id}")
except Exception:
logger.info("Processing trace tasks success, app_id: %s", app_id)
except Exception as e:
logger.info("error:\n\n\n%s\n\n\n\n", e)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
logging.info(f"Processing trace tasks failed, app_id: {app_id}")
logger.info("Processing trace tasks failed, app_id: %s", app_id)
finally:
storage.delete(file_path)
@@ -0,0 +1,166 @@
import traceback
import typing
import click
from celery import shared_task
from core.helper import marketplace
from core.helper.marketplace import MarketplacePluginDeclaration
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from models.account import TenantPluginAutoUpgradeStrategy
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
def marketplace_batch_fetch_plugin_manifests(
plugin_ids_plain_list: list[str],
) -> list[MarketplacePluginDeclaration]:
global cached_plugin_manifests
# return marketplace.batch_fetch_plugin_manifests(plugin_ids_plain_list)
not_included_plugin_ids = [
plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests
]
if not_included_plugin_ids:
manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_included_plugin_ids)
for manifest in manifests:
cached_plugin_manifests[manifest.plugin_id] = manifest
if (
len(manifests) == 0
): # this indicates that the plugin not found in marketplace, should set None in cache to prevent future check
for plugin_id in not_included_plugin_ids:
cached_plugin_manifests[plugin_id] = None
result: list[MarketplacePluginDeclaration] = []
for plugin_id in plugin_ids_plain_list:
final_manifest = cached_plugin_manifests.get(plugin_id)
if final_manifest is not None:
result.append(final_manifest)
return result
@shared_task(queue="plugin")
def process_tenant_plugin_autoupgrade_check_task(
tenant_id: str,
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
upgrade_time_of_day: int,
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
exclude_plugins: list[str],
include_plugins: list[str],
):
try:
manager = PluginInstaller()
click.echo(
click.style(
f"Checking upgradable plugin for tenant: {tenant_id}",
fg="green",
)
)
if strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED:
return
# get plugin_ids to check
plugin_ids: list[tuple[str, str, str]] = [] # plugin_id, version, unique_identifier
click.echo(click.style(f"Upgrade mode: {upgrade_mode}", fg="green"))
if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins:
all_plugins = manager.list_plugins(tenant_id)
for plugin in all_plugins:
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id in include_plugins:
plugin_ids.append(
(
plugin.plugin_id,
plugin.version,
plugin.plugin_unique_identifier,
)
)
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
# get all plugins and remove excluded plugins
all_plugins = manager.list_plugins(tenant_id)
plugin_ids = [
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
for plugin in all_plugins
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id not in exclude_plugins
]
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
all_plugins = manager.list_plugins(tenant_id)
plugin_ids = [
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
for plugin in all_plugins
if plugin.source == PluginInstallationSource.Marketplace
]
if not plugin_ids:
return
plugin_ids_plain_list = [plugin_id for plugin_id, _, _ in plugin_ids]
manifests = marketplace_batch_fetch_plugin_manifests(plugin_ids_plain_list)
if not manifests:
return
for manifest in manifests:
for plugin_id, version, original_unique_identifier in plugin_ids:
if manifest.plugin_id != plugin_id:
continue
try:
current_version = version
latest_version = manifest.latest_version
def fix_only_checker(latest_version, current_version):
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
current_version_tuple = tuple(int(val) for val in current_version.split("."))
if (
latest_version_tuple[0] == current_version_tuple[0]
and latest_version_tuple[1] == current_version_tuple[1]
):
return latest_version_tuple[2] != current_version_tuple[2]
return False
version_checker = {
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
current_version: latest_version != current_version,
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
}
if version_checker[strategy_setting](latest_version, current_version):
# execute upgrade
new_unique_identifier = manifest.latest_package_identifier
marketplace.record_install_plugin_event(new_unique_identifier)
click.echo(
click.style(
f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}",
fg="green",
)
)
_ = manager.upgrade_plugin(
tenant_id,
original_unique_identifier,
new_unique_identifier,
PluginInstallationSource.Marketplace,
{
"plugin_unique_identifier": new_unique_identifier,
},
)
except Exception as e:
click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red"))
traceback.print_exc()
break
except Exception as e:
click.echo(click.style(f"Error when checking upgradable plugin: {e}", fg="red"))
traceback.print_exc()
return
+9 -9
View File
@@ -2,12 +2,14 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from models.dataset import Document
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def recover_document_indexing_task(dataset_id: str, document_id: str):
@@ -18,13 +20,13 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
Usage: recover_document_indexing_task.delay(dataset_id, document_id)
"""
logging.info(click.style("Recover document: {}".format(document_id), fg="green"))
logger.info(click.style(f"Recover document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
@@ -37,12 +39,10 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
elif document.indexing_status == "indexing":
indexing_runner.run_in_indexing_status(document)
end_at = time.perf_counter()
logging.info(
click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green")
)
logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
finally:
db.session.close()
+143 -58
View File
@@ -3,16 +3,19 @@ import time
from collections.abc import Callable
import click
from celery import shared_task # type: ignore
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.model import (
from models import (
ApiToken,
AppAnnotationHitHistory,
AppAnnotationSetting,
AppDatasetJoin,
AppMCPServer,
AppModelConfig,
Conversation,
EndUser,
@@ -30,17 +33,25 @@ from models.model import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
from models.workflow import (
ConversationVariable,
Workflow,
WorkflowAppLog,
)
from repositories.factory import DifyAPIRepositoryFactory
logger = logging.getLogger(__name__)
@shared_task(queue="app_deletion", bind=True, max_retries=3)
def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
logging.info(click.style(f"Start deleting app and related data: {tenant_id}:{app_id}", fg="green"))
logger.info(click.style(f"Start deleting app and related data: {tenant_id}:{app_id}", fg="green"))
start_at = time.perf_counter()
try:
# Delete related data
_delete_app_model_configs(tenant_id, app_id)
_delete_app_site(tenant_id, app_id)
_delete_app_mcp_servers(tenant_id, app_id)
_delete_app_api_tokens(tenant_id, app_id)
_delete_installed_apps(tenant_id, app_id)
_delete_recommended_apps(tenant_id, app_id)
@@ -57,22 +68,21 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_end_users(tenant_id, app_id)
_delete_trace_app_configs(tenant_id, app_id)
_delete_conversation_variables(app_id=app_id)
_delete_draft_variables(app_id)
end_at = time.perf_counter()
logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
except SQLAlchemyError as e:
logging.exception(
click.style(f"Database error occurred while deleting app {app_id} and related data", fg="red")
)
logger.exception(click.style(f"Database error occurred while deleting app {app_id} and related data", fg="red"))
raise self.retry(exc=e, countdown=60) # Retry after 60 seconds
except Exception as e:
logging.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg="red"))
logger.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg="red"))
raise self.retry(exc=e, countdown=60) # Retry after 60 seconds
def _delete_app_model_configs(tenant_id: str, app_id: str):
def del_model_config(model_config_id: str):
db.session.query(AppModelConfig).filter(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
@@ -84,23 +94,43 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
def _delete_app_site(tenant_id: str, app_id: str):
def del_site(site_id: str):
db.session.query(Site).filter(Site.id == site_id).delete(synchronize_session=False)
db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site")
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_site,
"site",
)
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def del_mcp_server(mcp_server_id: str):
db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_mcp_server,
"app mcp server",
)
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(api_token_id: str):
db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False)
db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token"
"""select id from api_tokens where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_api_token,
"api token",
)
def _delete_installed_apps(tenant_id: str, app_id: str):
def del_installed_app(installed_app_id: str):
db.session.query(InstalledApp).filter(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -112,7 +142,7 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
def _delete_recommended_apps(tenant_id: str, app_id: str):
def del_recommended_app(recommended_app_id: str):
db.session.query(RecommendedApp).filter(RecommendedApp.id == recommended_app_id).delete(
db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
synchronize_session=False
)
@@ -126,9 +156,9 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str):
def del_annotation_hit_history(annotation_hit_history_id: str):
db.session.query(AppAnnotationHitHistory).filter(
AppAnnotationHitHistory.id == annotation_hit_history_id
).delete(synchronize_session=False)
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from app_annotation_hit_histories where app_id=:app_id limit 1000""",
@@ -138,7 +168,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
)
def del_annotation_setting(annotation_setting_id: str):
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.id == annotation_setting_id).delete(
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
)
@@ -152,7 +182,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def del_dataset_join(dataset_join_id: str):
db.session.query(AppDatasetJoin).filter(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
@@ -164,7 +194,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def _delete_app_workflows(tenant_id: str, app_id: str):
def del_workflow(workflow_id: str):
db.session.query(Workflow).filter(Workflow.id == workflow_id).delete(synchronize_session=False)
db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -175,34 +205,36 @@ def _delete_app_workflows(tenant_id: str, app_id: str):
def _delete_app_workflow_runs(tenant_id: str, app_id: str):
def del_workflow_run(workflow_run_id: str):
db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).delete(synchronize_session=False)
"""Delete all workflow runs for an app using the service repository."""
session_maker = sessionmaker(bind=db.engine)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
_delete_records(
"""select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_workflow_run,
"workflow run",
deleted_count = workflow_run_repo.delete_runs_by_app(
tenant_id=tenant_id,
app_id=app_id,
batch_size=1000,
)
logger.info("Deleted %s workflow runs for app %s", deleted_count, app_id)
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def del_workflow_node_execution(workflow_node_execution_id: str):
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete(
synchronize_session=False
)
"""Delete all workflow node executions for an app using the service repository."""
session_maker = sessionmaker(bind=db.engine)
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
_delete_records(
"""select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_workflow_node_execution,
"workflow node execution",
deleted_count = node_execution_repo.delete_executions_by_app(
tenant_id=tenant_id,
app_id=app_id,
batch_size=1000,
)
logger.info("Deleted %s workflow node executions for app %s", deleted_count, app_id)
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str):
db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete(
db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
synchronize_session=False
)
@@ -216,10 +248,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def _delete_app_conversations(tenant_id: str, app_id: str):
def del_conversation(conversation_id: str):
db.session.query(PinnedConversation).filter(PinnedConversation.conversation_id == conversation_id).delete(
db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(Conversation).filter(Conversation.id == conversation_id).delete(synchronize_session=False)
db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
@@ -234,33 +266,36 @@ def _delete_conversation_variables(*, app_id: str):
with db.engine.connect() as conn:
conn.execute(stmt)
conn.commit()
logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(message_id: str):
db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message_id).delete(
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message_id).delete(
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete(synchronize_session=False)
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message_id).delete(
db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageFile).filter(MessageFile.message_id == message_id).delete(synchronize_session=False)
db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete(synchronize_session=False)
db.session.query(Message).filter(Message.id == message_id).delete()
db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
db.session.query(Message).where(Message.id == message_id).delete()
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message"
"""select id from messages where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_message,
"message",
)
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def del_tool_provider(tool_provider_id: str):
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.id == tool_provider_id).delete(
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
)
@@ -274,7 +309,7 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def del_tag_binding(tag_binding_id: str):
db.session.query(TagBinding).filter(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@@ -286,7 +321,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def _delete_end_users(tenant_id: str, app_id: str):
def del_end_user(end_user_id: str):
db.session.query(EndUser).filter(EndUser.id == end_user_id).delete(synchronize_session=False)
db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -298,7 +333,7 @@ def _delete_end_users(tenant_id: str, app_id: str):
def _delete_trace_app_configs(tenant_id: str, app_id: str):
def del_trace_app_config(trace_app_config_id: str):
db.session.query(TraceAppConfig).filter(TraceAppConfig.id == trace_app_config_id).delete(
db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
synchronize_session=False
)
@@ -310,10 +345,60 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
)
def _delete_draft_variables(app_id: str):
"""Delete all workflow draft variables for an app in batches."""
return delete_draft_variables_batch(app_id, batch_size=1000)
def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
"""
Delete draft variables for an app in batches.
Args:
app_id: The ID of the app whose draft variables should be deleted
batch_size: Number of records to delete per batch
Returns:
Total number of records deleted
"""
if batch_size <= 0:
raise ValueError("batch_size must be positive")
total_deleted = 0
while True:
with db.engine.begin() as conn:
# Get a batch of draft variable IDs
query_sql = """
SELECT id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""
result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
draft_var_ids = [row[0] for row in result]
if not draft_var_ids:
break
# Delete the batch
delete_sql = """
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""
deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
batch_deleted = deleted_result.rowcount
total_deleted += batch_deleted
logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
logger.info(click.style(f"Deleted {total_deleted} total draft variables for app {app_id}", fg="green"))
return total_deleted
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
with db.engine.begin() as conn:
rs = conn.execute(db.text(query_sql), params)
rs = conn.execute(sa.text(query_sql), params)
if rs.rowcount == 0:
break
@@ -322,8 +407,8 @@ def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: s
try:
delete_func(record_id)
db.session.commit()
logging.info(click.style(f"Deleted {name} {record_id}", fg="green"))
logger.info(click.style(f"Deleted {name} {record_id}", fg="green"))
except Exception:
logging.exception(f"Error occurred while deleting {name} {record_id}")
logger.exception("Error occurred while deleting %s %s", name, record_id)
continue
rs.close()
+16 -18
View File
@@ -1,15 +1,17 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Document, DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def remove_document_from_index_task(document_id: str):
@@ -19,21 +21,21 @@ def remove_document_from_index_task(document_id: str):
Usage: remove_document_from_index.delay(document_id)
"""
logging.info(click.style("Start remove document segments from index: {}".format(document_id), fg="green"))
logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).filter(Document.id == document_id).first()
document = db.session.query(Document).where(Document.id == document_id).first()
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
if document.indexing_status != "completed":
logging.info(click.style("Document is not completed, remove is not allowed: {}".format(document_id), fg="red"))
logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
db.session.close()
return
indexing_cache_key = "document_{}_indexing".format(document.id)
indexing_cache_key = f"document_{document.id}_indexing"
try:
dataset = document.dataset
@@ -43,32 +45,28 @@ def remove_document_from_index_task(document_id: str):
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
except Exception:
logging.exception(f"clean dataset {dataset.id} from index failed")
logger.exception("clean dataset %s from index failed", dataset.id)
# update segment to disable
db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).update(
db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
{
DocumentSegment.enabled: False,
DocumentSegment.disabled_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
DocumentSegment.disabled_at: naive_utc_now(),
DocumentSegment.disabled_by: document.disabled_by,
DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
DocumentSegment.updated_at: naive_utc_now(),
}
)
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style(
"Document removed from index: {} latency: {}".format(document.id, end_at - start_at), fg="green"
)
)
logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green"))
except Exception:
logging.exception("remove document from index failed")
logger.exception("remove document from index failed")
if not document.archived:
document.enabled = True
db.session.commit()
+77 -72
View File
@@ -1,17 +1,19 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
@@ -22,81 +24,84 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
"""
documents: list[Document] = []
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
tenant_id = dataset.tenant_id
for document_id in document_ids:
retry_indexing_cache_key = f"document_{document_id}_is_retried"
# check document limit
features = FeatureService.get_features(tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
return
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red"))
db.session.close()
return
for document_id in document_ids:
retry_indexing_cache_key = "document_{}_is_retried".format(document_id)
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
db.session.close()
return
logging.info(click.style("Start retry document: {}".format(document_id), fg="green"))
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(retry_indexing_cache_key)
logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception(
"retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
)
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
db.session.close()
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg="yellow"))
redis_client.delete(retry_indexing_cache_key)
pass
finally:
db.session.close()
end_at = time.perf_counter()
logging.info(click.style("Retry dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
raise e
finally:
db.session.close()
@@ -1,17 +1,19 @@
import datetime
import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
@@ -24,11 +26,11 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
raise ValueError("Dataset not found")
sync_indexing_cache_key = "document_{}_is_sync".format(document_id)
sync_indexing_cache_key = f"document_{document_id}_is_sync"
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
@@ -41,27 +43,27 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
)
except Exception as e:
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(sync_indexing_cache_key)
return
logging.info(click.style("Start sync website document: {}".format(document_id), fg="green"))
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
@@ -72,7 +74,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.processing_started_at = naive_utc_now()
db.session.add(document)
db.session.commit()
@@ -82,11 +84,11 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg="yellow"))
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(sync_indexing_cache_key)
pass
logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logging.info(click.style("Sync document: {} latency: {}".format(document_id, end_at - start_at), fg="green"))
logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
+136
View File
@@ -0,0 +1,136 @@
"""
Celery tasks for asynchronous workflow execution storage operations.
These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow execution to the database.
Args:
execution_data: Serialized WorkflowExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)
# Check if workflow run already exists
existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))
if existing_run:
# Update existing workflow run
_update_workflow_run_from_execution(existing_run, execution)
logger.debug("Updated existing workflow run: %s", execution.id_)
else:
# Create new workflow run
workflow_run = _create_workflow_run_from_execution(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(workflow_run)
logger.debug("Created new workflow run: %s", execution.id_)
session.commit()
return True
except Exception as e:
logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown"))
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_workflow_run_from_execution(
execution: WorkflowExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowRun:
"""
Create a WorkflowRun database model from a WorkflowExecution domain entity.
"""
workflow_run = WorkflowRun()
workflow_run.id = execution.id_
workflow_run.tenant_id = tenant_id
workflow_run.app_id = app_id
workflow_run.workflow_id = execution.workflow_id
workflow_run.type = execution.workflow_type.value
workflow_run.triggered_from = triggered_from.value
workflow_run.version = execution.workflow_version
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.created_by_role = creator_user_role.value
workflow_run.created_by = creator_user_id
workflow_run.created_at = execution.started_at
workflow_run.finished_at = execution.finished_at
return workflow_run
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
"""
Update a WorkflowRun database model from a WorkflowExecution domain entity.
"""
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.finished_at = execution.finished_at
+171
View File
@@ -0,0 +1,171 @@
"""
Celery tasks for asynchronous workflow node execution storage operations.
These tasks provide asynchronous storage capabilities for workflow node execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow node execution to the database.
Args:
execution_data: Serialized WorkflowNodeExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)
# Check if node execution already exists
existing_execution = session.scalar(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
)
if existing_execution:
# Update existing node execution
_update_node_execution_from_domain(existing_execution, execution)
logger.debug("Updated existing workflow node execution: %s", execution.id)
else:
# Create new node execution
node_execution = _create_node_execution_from_domain(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(node_execution)
logger.debug("Created new workflow node execution: %s", execution.id)
session.commit()
return True
except Exception as e:
logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_node_execution_from_domain(
execution: WorkflowNodeExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowNodeExecutionTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowNodeExecutionModel:
"""
Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
node_execution = WorkflowNodeExecutionModel()
node_execution.id = execution.id
node_execution.tenant_id = tenant_id
node_execution.app_id = app_id
node_execution.workflow_id = execution.workflow_id
node_execution.triggered_from = triggered_from.value
node_execution.workflow_run_id = execution.workflow_execution_id
node_execution.index = execution.index
node_execution.predecessor_node_id = execution.predecessor_node_id
node_execution.node_id = execution.node_id
node_execution.node_type = execution.node_type.value
node_execution.title = execution.title
node_execution.node_execution_id = execution.node_execution_id
# Serialize complex data as JSON
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.created_by_role = creator_user_role.value
node_execution.created_by = creator_user_id
node_execution.created_at = execution.created_at
node_execution.finished_at = execution.finished_at
return node_execution
def _update_node_execution_from_domain(
node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
) -> None:
"""
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
# Update serialized data
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
# Update other fields
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.finished_at = execution.finished_at