refactor: use session factory instead of call db.session directly (#31198)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei
2026-01-21 13:43:06 +08:00
committed by GitHub
parent 071bbc6d74
commit 121d301a41
48 changed files with 2788 additions and 2693 deletions
@@ -4,8 +4,8 @@ from unittest.mock import patch
import pytest
from sqlalchemy import delete
from core.db.session_factory import session_factory
from core.variables.segments import StringSegment
from extensions.ext_database import db
from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
@@ -16,26 +16,23 @@ from tasks.remove_app_and_related_data_task import _delete_draft_variables, dele
@pytest.fixture
def app_and_tenant(flask_req_ctx):
tenant_id = uuid.uuid4()
tenant = Tenant(
id=tenant_id,
name="test_tenant",
)
db.session.add(tenant)
with session_factory.create_session() as session:
tenant = Tenant(name="test_tenant")
session.add(tenant)
session.flush()
app = App(
tenant_id=tenant_id, # Now tenant.id will have a value
name=f"Test App for tenant {tenant.id}",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app)
db.session.flush()
yield (tenant, app)
app = App(
tenant_id=tenant.id,
name=f"Test App for tenant {tenant.id}",
mode="workflow",
enable_site=True,
enable_api=True,
)
session.add(app)
session.flush()
# Cleanup with proper error handling
db.session.delete(app)
db.session.delete(tenant)
# return detached objects (ids will be used by tests)
return (tenant, app)
class TestDeleteDraftVariablesIntegration:
@@ -44,334 +41,285 @@ class TestDeleteDraftVariablesIntegration:
"""Create test data with apps and draft variables."""
tenant, app = app_and_tenant
# Create a second app for testing
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app2)
db.session.commit()
# Create draft variables for both apps
variables_app1 = []
variables_app2 = []
for i in range(5):
var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
with session_factory.create_session() as session:
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(var1)
variables_app1.append(var1)
session.add(app2)
session.flush()
var2 = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var2)
variables_app2.append(var2)
variables_app1 = []
variables_app2 = []
for i in range(5):
var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var1)
variables_app1.append(var1)
# Commit all the variables to the database
db.session.commit()
var2 = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var2)
variables_app2.append(var2)
session.commit()
app2_id = app2.id
yield {
"app1": app,
"app2": app2,
"app2": App(id=app2_id), # dummy with id to avoid open session
"tenant": tenant,
"variables_app1": variables_app1,
"variables_app2": variables_app2,
}
# Cleanup - refresh session and check if objects still exist
db.session.rollback() # Clear any pending changes
# Clean up remaining variables
cleanup_query = (
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id.in_([app.id, app2.id]),
with session_factory.create_session() as session:
cleanup_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id.in_([app.id, app2_id]))
.execution_options(synchronize_session=False)
)
.execution_options(synchronize_session=False)
)
db.session.execute(cleanup_query)
# Clean up app2
app2_obj = db.session.get(App, app2.id)
if app2_obj:
db.session.delete(app2_obj)
db.session.commit()
session.execute(cleanup_query)
app2_obj = session.get(App, app2_id)
if app2_obj:
session.delete(app2_obj)
session.commit()
def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data):
"""Test that batch deletion only removes variables for the specified app."""
data = setup_test_data
app1_id = data["app1"].id
app2_id = data["app2"].id
# Verify initial state
app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
with session_factory.create_session() as session:
app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_before == 5
assert app2_vars_before == 5
# Delete app1 variables
deleted_count = delete_draft_variables_batch(app1_id, batch_size=10)
# Verify results
assert deleted_count == 5
app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_after == 0 # All app1 variables deleted
assert app2_vars_after == 5 # App2 variables unchanged
with session_factory.create_session() as session:
app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_after == 0
assert app2_vars_after == 5
def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data):
"""Test batch deletion with small batch size processes all records."""
data = setup_test_data
app1_id = data["app1"].id
# Use small batch size to force multiple batches
deleted_count = delete_draft_variables_batch(app1_id, batch_size=2)
assert deleted_count == 5
# Verify all variables are deleted
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
with session_factory.create_session() as session:
remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert remaining_vars == 0
def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
"""Test that deleting variables for nonexistent app returns 0."""
nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format
nonexistent_app_id = str(uuid.uuid4())
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100)
assert deleted_count == 0
def test_delete_draft_variables_wrapper_function(self, setup_test_data):
"""Test that _delete_draft_variables wrapper function works correctly."""
data = setup_test_data
app1_id = data["app1"].id
# Verify initial state
vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
with session_factory.create_session() as session:
vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert vars_before == 5
# Call wrapper function
deleted_count = _delete_draft_variables(app1_id)
# Verify results
assert deleted_count == 5
vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
with session_factory.create_session() as session:
vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert vars_after == 0
def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
"""Test batch deletion with larger dataset to verify batching logic."""
tenant, app = app_and_tenant
# Create many draft variables
variables = []
for i in range(25):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var)
variables.append(var)
variable_ids = [i.id for i in variables]
# Commit the variables to the database
db.session.commit()
variable_ids: list[str] = []
with session_factory.create_session() as session:
variables = []
for i in range(25):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
variables.append(var)
session.commit()
variable_ids = [v.id for v in variables]
try:
# Use small batch size to force multiple batches
deleted_count = delete_draft_variables_batch(app.id, batch_size=8)
assert deleted_count == 25
# Verify all variables are deleted
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
assert remaining_vars == 0
with session_factory.create_session() as session:
remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
assert remaining == 0
finally:
query = (
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.id.in_(variable_ids),
with session_factory.create_session() as session:
query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.id.in_(variable_ids))
.execution_options(synchronize_session=False)
)
.execution_options(synchronize_session=False)
)
db.session.execute(query)
session.execute(query)
session.commit()
class TestDeleteDraftVariablesWithOffloadIntegration:
"""Integration tests for draft variable deletion with Offload data."""
@pytest.fixture
def setup_offload_test_data(self, app_and_tenant):
"""Create test data with draft variables that have associated Offload files."""
tenant, app = app_and_tenant
# Create UploadFile records
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
db.session.add(upload_file1)
db.session.add(upload_file2)
db.session.flush()
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file1)
session.add(upload_file2)
session.flush()
# Create WorkflowDraftVariableFile records
from core.variables.types import SegmentType
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
session.add(var_file1)
session.add(var_file2)
session.flush()
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
db.session.add(var_file1)
db.session.add(var_file2)
db.session.flush()
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(draft_var1)
session.add(draft_var2)
session.add(draft_var3)
session.commit()
# Create WorkflowDraftVariable records with file associations
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
# Create a regular variable without Offload data
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
data = {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
db.session.add(draft_var1)
db.session.add(draft_var2)
db.session.add(draft_var3)
db.session.commit()
yield data
yield {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
# Cleanup
db.session.rollback()
# Clean up any remaining records
for table, ids in [
(WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
(WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
(UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
db.session.execute(cleanup_query)
db.session.commit()
with session_factory.create_session() as session:
session.rollback()
for table, ids in [
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
(UploadFile, [uf.id for uf in data["upload_files"]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
session.execute(cleanup_query)
session.commit()
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
"""Test that deleting draft variables also cleans up associated Offload data."""
data = setup_offload_test_data
app_id = data["app"].id
# Mock storage deletion to succeed
mock_storage.delete.return_value = None
# Verify initial state
draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = db.session.query(WorkflowDraftVariableFile).count()
upload_files_before = db.session.query(UploadFile).count()
assert draft_vars_before == 3 # 2 with files + 1 regular
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = session.query(WorkflowDraftVariableFile).count()
upload_files_before = session.query(UploadFile).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
# Delete draft variables
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
# Verify results
assert deleted_count == 3
# Check that all draft variables are deleted
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert draft_vars_after == 0
# Check that associated Offload data is cleaned up
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
upload_files_after = db.session.query(UploadFile).count()
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
assert var_files_after == 0
assert upload_files_after == 0
assert var_files_after == 0 # All variable files should be deleted
assert upload_files_after == 0 # All upload files should be deleted
# Verify storage deletion was called for both files
assert mock_storage.delete.call_count == 2
storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
assert "test/file1.json" in storage_keys_deleted
@@ -379,92 +327,71 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
"""Test that database cleanup continues even when storage deletion fails."""
data = setup_offload_test_data
app_id = data["app"].id
# Mock storage deletion to fail for first file, succeed for second
mock_storage.delete.side_effect = [Exception("Storage error"), None]
# Delete draft variables
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
# Verify that all draft variables are still deleted
assert deleted_count == 3
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert draft_vars_after == 0
# Database cleanup should still succeed even with storage errors
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
upload_files_after = db.session.query(UploadFile).count()
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
assert var_files_after == 0
assert upload_files_after == 0
# Verify storage deletion was attempted for both files
assert mock_storage.delete.call_count == 2
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
"""Test deletion with mix of variables with and without Offload data."""
data = setup_offload_test_data
app_id = data["app"].id
# Create additional app with only regular variables (no offload data)
tenant = data["tenant"]
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app2)
db.session.flush()
# Add regular variables to app2
regular_vars = []
for i in range(3):
var = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
with session_factory.create_session() as session:
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(var)
regular_vars.append(var)
db.session.commit()
session.add(app2)
session.flush()
for i in range(3):
var = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
session.commit()
try:
# Mock storage deletion
mock_storage.delete.return_value = None
# Delete variables for app2 (no offload data)
deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
assert deleted_count_app2 == 3
# Verify storage wasn't called for app2 (no offload files)
mock_storage.delete.assert_not_called()
# Delete variables for original app (with offload data)
deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
assert deleted_count_app1 == 3
# Now storage should be called for the offload files
assert mock_storage.delete.call_count == 2
finally:
# Cleanup app2 and its variables
cleanup_vars_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app2.id)
.execution_options(synchronize_session=False)
)
db.session.execute(cleanup_vars_query)
app2_obj = db.session.get(App, app2.id)
if app2_obj:
db.session.delete(app2_obj)
db.session.commit()
with session_factory.create_session() as session:
cleanup_vars_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app2.id)
.execution_options(synchronize_session=False)
)
session.execute(cleanup_vars_query)
app2_obj = session.get(App, app2.id)
if app2_obj:
session.delete(app2_obj)
session.commit()
@@ -39,23 +39,22 @@ class TestCleanDatasetTask:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
"""Clean up database before each test to ensure isolation."""
from extensions.ext_database import db
from extensions.ext_redis import redis_client
# Clear all test data
db.session.query(DatasetMetadataBinding).delete()
db.session.query(DatasetMetadata).delete()
db.session.query(AppDatasetJoin).delete()
db.session.query(DatasetQuery).delete()
db.session.query(DatasetProcessRule).delete()
db.session.query(DocumentSegment).delete()
db.session.query(Document).delete()
db.session.query(Dataset).delete()
db.session.query(UploadFile).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
# Clear all test data using the provided session fixture
db_session_with_containers.query(DatasetMetadataBinding).delete()
db_session_with_containers.query(DatasetMetadata).delete()
db_session_with_containers.query(AppDatasetJoin).delete()
db_session_with_containers.query(DatasetQuery).delete()
db_session_with_containers.query(DatasetProcessRule).delete()
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
# Clear Redis cache
redis_client.flushdb()
@@ -103,10 +102,8 @@ class TestCleanDatasetTask:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
@@ -115,8 +112,8 @@ class TestCleanDatasetTask:
status="active",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account relationship
tenant_account_join = TenantAccountJoin(
@@ -125,8 +122,8 @@ class TestCleanDatasetTask:
role=TenantAccountRole.OWNER,
)
db.session.add(tenant_account_join)
db.session.commit()
db_session_with_containers.add(tenant_account_join)
db_session_with_containers.commit()
return account, tenant
@@ -155,10 +152,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@@ -194,10 +189,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
@@ -232,10 +225,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segment
@@ -267,10 +258,8 @@ class TestCleanDatasetTask:
used=False,
)
from extensions.ext_database import db
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
@@ -302,31 +291,29 @@ class TestCleanDatasetTask:
)
# Verify results
from extensions.ext_database import db
# Check that dataset-related data was cleaned up
documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(documents) == 0
segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(segments) == 0
# Check that metadata and bindings were cleaned up
metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(metadata) == 0
bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
assert len(bindings) == 0
# Check that process rules and queries were cleaned up
process_rules = db.session.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
assert len(process_rules) == 0
queries = db.session.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
assert len(queries) == 0
# Check that app dataset joins were cleaned up
app_joins = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
assert len(app_joins) == 0
# Verify index processor was called
@@ -378,9 +365,7 @@ class TestCleanDatasetTask:
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Create dataset metadata and bindings
metadata = DatasetMetadata(
@@ -403,11 +388,9 @@ class TestCleanDatasetTask:
binding.id = str(uuid.uuid4())
binding.created_at = datetime.now()
from extensions.ext_database import db
db.session.add(metadata)
db.session.add(binding)
db.session.commit()
db_session_with_containers.add(metadata)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
# Execute the task
clean_dataset_task(
@@ -421,22 +404,24 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
assert len(remaining_files) == 0
# Check that metadata and bindings were cleaned up
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_bindings) == 0
# Verify index processor was called
@@ -489,12 +474,13 @@ class TestCleanDatasetTask:
mock_index_processor.clean.assert_called_once()
# Check that all data was cleaned up
from extensions.ext_database import db
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = (
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_segments) == 0
# Recreate data for next test case
@@ -540,14 +526,13 @@ class TestCleanDatasetTask:
)
# Verify results - even with vector cleanup failure, documents and segments should be deleted
from extensions.ext_database import db
# Check that documents were still deleted despite vector cleanup failure
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that segments were still deleted despite vector cleanup failure
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Verify that index processor was called and failed
@@ -608,10 +593,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Mock the get_image_upload_file_ids function to return our image file IDs
with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids:
@@ -629,16 +612,18 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all image files were deleted from database
image_file_ids = [f.id for f in image_files]
remaining_image_files = db.session.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
remaining_image_files = (
db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
)
assert len(remaining_image_files) == 0
# Verify that storage.delete was called for each image file
@@ -745,22 +730,24 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
assert len(remaining_files) == 0
# Check that all metadata and bindings were deleted
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_bindings) == 0
# Verify performance expectations
@@ -808,9 +795,7 @@ class TestCleanDatasetTask:
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock storage to raise exceptions
mock_storage = mock_external_service_dependencies["storage"]
@@ -827,18 +812,13 @@ class TestCleanDatasetTask:
)
# Verify results
# Check that documents were still deleted despite storage failure
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that segments were still deleted despite storage failure
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Note: When storage operations fail, database deletions may be rolled back by implementation.
# This test focuses on ensuring the task handles the exception and continues execution/logging.
# Check that upload file was still deleted from database despite storage failure
# Note: When storage operations fail, the upload file may not be deleted
# This demonstrates that the cleanup process continues even with storage errors
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file.id).all()
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all()
# The upload file should still be deleted from the database even if storage cleanup fails
# However, this depends on the specific implementation of clean_dataset_task
if len(remaining_files) > 0:
@@ -890,10 +870,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create document with special characters in name
special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~"
@@ -912,8 +890,8 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
updated_at=datetime.now(),
)
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Create segment with special characters and very long content
long_content = "Very long content " * 100 # Long content within reasonable limits
@@ -934,8 +912,8 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
updated_at=datetime.now(),
)
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Create upload file with special characters in name
special_filename = f"test_file_{special_content}.txt"
@@ -952,14 +930,14 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
used=False,
)
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
# Update document with file reference
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
db.session.commit()
db_session_with_containers.commit()
# Save upload file ID for verification
upload_file_id = upload_file.id
@@ -975,8 +953,8 @@ class TestCleanDatasetTask:
special_metadata.id = str(uuid.uuid4())
special_metadata.created_at = datetime.now()
db.session.add(special_metadata)
db.session.commit()
db_session_with_containers.add(special_metadata)
db_session_with_containers.commit()
# Execute the task
clean_dataset_task(
@@ -990,19 +968,19 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all()
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all()
assert len(remaining_files) == 0
# Check that all metadata was deleted
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
# Verify that storage.delete was called
@@ -24,16 +24,15 @@ class TestCreateSegmentToIndexTask:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
"""Clean up database and Redis before each test to ensure isolation."""
from extensions.ext_database import db
# Clear all test data
db.session.query(DocumentSegment).delete()
db.session.query(Document).delete()
db.session.query(Dataset).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
# Clear all test data using fixture session
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
# Clear Redis cache
redis_client.flushdb()
@@ -73,10 +72,8 @@ class TestCreateSegmentToIndexTask:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
@@ -84,8 +81,8 @@ class TestCreateSegmentToIndexTask:
status="normal",
plan="basic",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join with owner role
join = TenantAccountJoin(
@@ -94,8 +91,8 @@ class TestCreateSegmentToIndexTask:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@@ -746,20 +743,9 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
)
# Mock global database session to simulate transaction issues
from extensions.ext_database import db
original_commit = db.session.commit
commit_called = False
def mock_commit():
nonlocal commit_called
if not commit_called:
commit_called = True
raise Exception("Database commit failed")
return original_commit()
db.session.commit = mock_commit
# Simulate an error during indexing to trigger rollback path
mock_processor = mock_external_service_dependencies["index_processor"]
mock_processor.load.side_effect = Exception("Simulated indexing error")
# Act: Execute the task
create_segment_to_index_task(segment.id)
@@ -771,9 +757,6 @@ class TestCreateSegmentToIndexTask:
assert segment.disabled_at is not None
assert segment.error is not None
# Restore original commit method
db.session.commit = original_commit
def test_create_segment_to_index_metadata_validation(
self, db_session_with_containers, mock_external_service_dependencies
):
@@ -70,11 +70,9 @@ class TestDisableSegmentsFromIndexTask:
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
from extensions.ext_database import db
db.session.add(tenant)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Set the current tenant for the account
account.current_tenant = tenant
@@ -110,10 +108,8 @@ class TestDisableSegmentsFromIndexTask:
built_in_field_enabled=False,
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@@ -158,10 +154,8 @@ class TestDisableSegmentsFromIndexTask:
document.archived = False
document.doc_form = "text_model" # Use text_model form for testing
document.doc_language = "en"
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
@@ -211,11 +205,9 @@ class TestDisableSegmentsFromIndexTask:
segments.append(segment)
from extensions.ext_database import db
for segment in segments:
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segments
@@ -645,15 +637,12 @@ class TestDisableSegmentsFromIndexTask:
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
mock_redis.delete.return_value = True
# Mock db.session.close to verify it's called
with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close:
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
# Verify session was closed
mock_close.assert_called()
# Assert
assert result is None # Task should complete without returning a value
# Session lifecycle is managed by context manager; no explicit close assertion
def test_disable_segments_empty_segment_ids(self, db_session_with_containers):
"""
@@ -6,7 +6,6 @@ from faker import Faker
from core.entities.document_task import DocumentTask
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from tasks.document_indexing_task import (
@@ -75,15 +74,15 @@ class TestDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@@ -92,8 +91,8 @@ class TestDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@@ -105,8 +104,8 @@ class TestDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@@ -124,13 +123,13 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@@ -157,15 +156,15 @@ class TestDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@@ -174,8 +173,8 @@ class TestDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@@ -187,8 +186,8 @@ class TestDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@@ -206,10 +205,10 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
@@ -219,7 +218,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["features"].vector_space.size = 50
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@@ -242,6 +241,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -250,7 +252,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -310,6 +312,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with mixed document IDs
_document_indexing(dataset.id, all_document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@@ -317,7 +322,7 @@ class TestDocumentIndexingTasks:
# Verify only existing documents were updated
# Re-query documents from database since _document_indexing uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -353,6 +358,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -361,7 +369,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing close the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -400,7 +408,7 @@ class TestDocumentIndexingTasks:
indexing_status="completed", # Already completed
enabled=True,
)
db.session.add(doc1)
db_session_with_containers.add(doc1)
extra_documents.append(doc1)
# Document with disabled status
@@ -417,10 +425,10 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=False, # Disabled
)
db.session.add(doc2)
db_session_with_containers.add(doc2)
extra_documents.append(doc2)
db.session.commit()
db_session_with_containers.commit()
all_documents = base_documents + extra_documents
document_ids = [doc.id for doc in all_documents]
@@ -428,6 +436,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with mixed document states
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@@ -435,7 +446,7 @@ class TestDocumentIndexingTasks:
# Verify all documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -482,20 +493,23 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
extra_documents.append(document)
db.session.commit()
db_session_with_containers.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error
@@ -526,6 +540,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with billing disabled
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify successful processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@@ -533,7 +550,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -565,6 +582,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -573,7 +593,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -674,6 +694,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify core processing occurred (same as _document_indexing)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@@ -681,7 +704,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated (same as _document_indexing)
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -794,6 +817,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -802,7 +828,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -865,6 +891,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function for tenant1 only
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify core processing occurred for tenant1
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@@ -4,7 +4,6 @@ import pytest
from faker import Faker
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.duplicate_document_indexing_task import (
@@ -82,15 +81,15 @@ class TestDuplicateDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@@ -99,8 +98,8 @@ class TestDuplicateDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@@ -112,8 +111,8 @@ class TestDuplicateDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@@ -132,13 +131,13 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@@ -183,14 +182,14 @@ class TestDuplicateDocumentIndexingTasks:
indexing_at=fake.date_time_this_year(),
created_by=dataset.created_by, # Add required field
)
db.session.add(segment)
db_session_with_containers.add(segment)
segments.append(segment)
db.session.commit()
db_session_with_containers.commit()
# Refresh to ensure all relationships are loaded
for document in documents:
db.session.refresh(document)
db_session_with_containers.refresh(document)
return dataset, documents, segments
@@ -217,15 +216,15 @@ class TestDuplicateDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@@ -234,8 +233,8 @@ class TestDuplicateDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@@ -247,8 +246,8 @@ class TestDuplicateDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@@ -267,10 +266,10 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
@@ -280,7 +279,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["features"].vector_space.size = 50
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@@ -305,6 +304,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -313,7 +315,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -340,23 +342,32 @@ class TestDuplicateDocumentIndexingTasks:
db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3
)
document_ids = [doc.id for doc in documents]
segment_ids = [seg.id for seg in segments]
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify segment cleanup
db_session_with_containers.expire_all()
# Assert: Verify segment cleanup
# Verify index processor clean was called for each document with segments
assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents)
# Verify segments were deleted from database
# Re-query segments from database since _duplicate_document_indexing_task uses a different session
for segment in segments:
deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
# Re-query segments from database using captured IDs to avoid stale ORM instances
for seg_id in segment_ids:
deleted_segment = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first()
)
assert deleted_segment is None
# Verify documents were updated to parsing status
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -415,6 +426,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task with mixed document IDs
_duplicate_document_indexing_task(dataset.id, all_document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@@ -422,7 +436,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify only existing documents were updated
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -458,6 +472,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@@ -466,7 +483,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _duplicate_document_indexing_task close the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@@ -508,20 +525,23 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
extra_documents.append(document)
db.session.commit()
db_session_with_containers.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error.lower()
@@ -557,10 +577,13 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task with documents that will exceed vector space limit
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "limit" in updated_document.error.lower()
@@ -620,11 +643,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@@ -663,11 +686,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@@ -707,11 +730,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@@ -49,10 +49,14 @@ def pipeline_id():
@pytest.fixture
def mock_db_session():
"""Mock database session with query capabilities."""
with patch("tasks.clean_dataset_task.db") as mock_db:
"""Mock database session via session_factory.create_session()."""
with patch("tasks.clean_dataset_task.session_factory") as mock_sf:
mock_session = MagicMock()
mock_db.session = mock_session
# context manager for create_session()
cm = MagicMock()
cm.__enter__.return_value = mock_session
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm
# Setup query chain
mock_query = MagicMock()
@@ -66,7 +70,10 @@ def mock_db_session():
# Setup execute for JOIN queries
mock_session.execute.return_value.all.return_value = []
yield mock_db
# Yield an object with a `.session` attribute to keep tests unchanged
wrapper = MagicMock()
wrapper.session = mock_session
yield wrapper
@pytest.fixture
@@ -227,7 +234,9 @@ class TestBasicCleanup:
# Assert
mock_db_session.session.delete.assert_any_call(mock_document)
mock_db_session.session.delete.assert_any_call(mock_segment)
# Segments are deleted in batch; verify a DELETE on document_segments was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
mock_db_session.session.commit.assert_called_once()
def test_clean_dataset_task_deletes_related_records(
@@ -413,7 +422,9 @@ class TestErrorHandling:
# Assert - documents and segments should still be deleted
mock_db_session.session.delete.assert_any_call(mock_document)
mock_db_session.session.delete.assert_any_call(mock_segment)
# Segments are deleted in batch; verify a DELETE on document_segments was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
mock_db_session.session.commit.assert_called_once()
def test_clean_dataset_task_storage_delete_failure_continues(
@@ -461,7 +472,7 @@ class TestErrorHandling:
[mock_segment], # segments
]
mock_get_image_upload_file_ids.return_value = [image_file_id]
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
mock_storage.delete.side_effect = Exception("Storage service unavailable")
# Act
@@ -476,8 +487,9 @@ class TestErrorHandling:
# Assert - storage delete was attempted for image file
mock_storage.delete.assert_called_with(mock_upload_file.key)
# Image file should still be deleted from database
mock_db_session.session.delete.assert_any_call(mock_upload_file)
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
def test_clean_dataset_task_database_error_rollback(
self,
@@ -691,8 +703,10 @@ class TestSegmentAttachmentCleanup:
# Assert
mock_storage.delete.assert_called_with(mock_attachment_file.key)
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
mock_db_session.session.delete.assert_any_call(mock_binding)
# Attachment file and binding are deleted in batch; verify DELETEs were issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)
def test_clean_dataset_task_attachment_storage_failure(
self,
@@ -734,9 +748,10 @@ class TestSegmentAttachmentCleanup:
# Assert - storage delete was attempted
mock_storage.delete.assert_called_once()
# Records should still be deleted from database
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
mock_db_session.session.delete.assert_any_call(mock_binding)
# Records are deleted in batch; verify DELETEs were issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)
# ============================================================================
@@ -784,7 +799,7 @@ class TestUploadFileCleanup:
[mock_document], # documents
[], # segments
]
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
# Act
clean_dataset_task(
@@ -798,7 +813,9 @@ class TestUploadFileCleanup:
# Assert
mock_storage.delete.assert_called_with(mock_upload_file.key)
mock_db_session.session.delete.assert_any_call(mock_upload_file)
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
def test_clean_dataset_task_handles_missing_upload_file(
self,
@@ -832,7 +849,7 @@ class TestUploadFileCleanup:
[mock_document], # documents
[], # segments
]
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
mock_db_session.session.query.return_value.where.return_value.all.return_value = []
# Act - should not raise exception
clean_dataset_task(
@@ -949,11 +966,11 @@ class TestImageFileCleanup:
[mock_segment], # segments
]
# Setup a mock query chain that returns files in sequence
# Setup a mock query chain that returns files in batch (align with .in_().all())
mock_query = MagicMock()
mock_where = MagicMock()
mock_query.where.return_value = mock_where
mock_where.first.side_effect = mock_image_files
mock_where.all.return_value = mock_image_files
mock_db_session.session.query.return_value = mock_query
# Act
@@ -966,10 +983,10 @@ class TestImageFileCleanup:
doc_form="paragraph_index",
)
# Assert
assert mock_storage.delete.call_count == 2
mock_storage.delete.assert_any_call("images/image-1.jpg")
mock_storage.delete.assert_any_call("images/image-2.jpg")
# Assert - each expected image key was deleted at least once
calls = [c.args[0] for c in mock_storage.delete.call_args_list]
assert "images/image-1.jpg" in calls
assert "images/image-2.jpg" in calls
def test_clean_dataset_task_handles_missing_image_file(
self,
@@ -1010,7 +1027,7 @@ class TestImageFileCleanup:
]
# Image file not found
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
mock_db_session.session.query.return_value.where.return_value.all.return_value = []
# Act - should not raise exception
clean_dataset_task(
@@ -1086,14 +1103,15 @@ class TestEdgeCases:
doc_form="paragraph_index",
)
# Assert - all documents and segments should be deleted
# Assert - all documents and segments should be deleted (documents per-entity, segments in batch)
delete_calls = mock_db_session.session.delete.call_args_list
deleted_items = [call[0][0] for call in delete_calls]
for doc in mock_documents:
assert doc in deleted_items
for seg in mock_segments:
assert seg in deleted_items
# Verify a batch DELETE on document_segments occurred
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
def test_clean_dataset_task_document_with_empty_data_source_info(
self,
@@ -81,12 +81,25 @@ def mock_documents(document_ids, dataset_id):
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.document_indexing_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests that expect session.close() to be called can observe it via the context manager
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
@pytest.fixture
@@ -18,12 +18,18 @@ from tasks.delete_account_task import delete_account_task
@pytest.fixture
def mock_db_session():
"""Mock the db.session used in delete_account_task."""
with patch("tasks.delete_account_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
yield mock_session
"""Mock session via session_factory.create_session()."""
with patch("tasks.delete_account_task.session_factory") as mock_sf:
session = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
@pytest.fixture
@@ -109,13 +109,25 @@ def mock_document_segments(document_id):
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_session.scalars.return_value = MagicMock()
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests can observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session
@pytest.fixture
@@ -251,8 +263,8 @@ class TestDocumentIndexingSyncTask:
# Assert
# Document status should remain unchanged
assert mock_document.indexing_status == "completed"
# No session operations should be performed beyond the initial query
mock_db_session.close.assert_not_called()
# Session should still be closed via context manager teardown
assert mock_db_session.close.called
def test_successful_sync_when_page_updated(
self,
@@ -286,9 +298,9 @@ class TestDocumentIndexingSyncTask:
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_called_once()
# Verify segments were deleted from database
for segment in mock_document_segments:
mock_db_session.delete.assert_any_call(segment)
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document])
@@ -94,13 +94,25 @@ def mock_document_segments(document_ids):
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_session.scalars.return_value = MagicMock()
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Allow tests to observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session
@pytest.fixture
@@ -200,8 +212,25 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test successful duplicate document indexing flow."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# Dataset via query.first()
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# scalars() call sequence:
# 1) documents list
# 2..N) segments per document
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
# First call returns documents; subsequent calls return segments
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = mock_document_segments
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
@@ -264,8 +293,21 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when billing limit is exceeded."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# First scalars() -> documents; subsequent -> empty segments
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_features = mock_feature_service.get_features.return_value
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.TEAM
@@ -294,8 +336,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when IndexingRunner raises an error."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_indexing_runner.run.side_effect = Exception("Indexing error")
# Act
@@ -318,8 +372,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when document is paused."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
# Act
@@ -343,8 +409,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test that duplicate document indexing cleans old segments."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = mock_document_segments
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
# Act
@@ -354,9 +432,9 @@ class TestDuplicateDocumentIndexingTaskCore:
# Verify clean was called for each document
assert mock_processor.clean.call_count == len(mock_documents)
# Verify segments were deleted
for segment in mock_document_segments:
mock_db_session.delete.assert_any_call(segment)
# Verify segments were deleted in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# ============================================================================
@@ -11,21 +11,18 @@ from tasks.remove_app_and_related_data_task import (
class TestDeleteDraftVariablesBatch:
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
def test_delete_draft_variables_batch_success(self, mock_db, mock_offload_cleanup):
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_success(self, mock_sf, mock_offload_cleanup):
"""Test successful deletion of draft variables in batches."""
app_id = "test-app-id"
batch_size = 100
# Mock database connection and engine
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager
# Mock two batches of results, then empty
batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)]
@@ -68,7 +65,7 @@ class TestDeleteDraftVariablesBatch:
select_result3.__iter__.return_value = iter([])
# Configure side effects in the correct order
mock_conn.execute.side_effect = [
mock_session.execute.side_effect = [
select_result1, # First SELECT
delete_result1, # First DELETE
select_result2, # Second SELECT
@@ -86,54 +83,49 @@ class TestDeleteDraftVariablesBatch:
assert result == 150
# Verify database calls
assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes
assert mock_session.execute.call_count == 5 # 3 selects + 2 deletes
# Verify offload cleanup was called for both batches with file_ids
expected_offload_calls = [call(mock_conn, batch1_file_ids), call(mock_conn, batch2_file_ids)]
expected_offload_calls = [call(mock_session, batch1_file_ids), call(mock_session, batch2_file_ids)]
mock_offload_cleanup.assert_has_calls(expected_offload_calls)
# Simplified verification - check that the right number of calls were made
# and that the SQL queries contain the expected patterns
actual_calls = mock_conn.execute.call_args_list
actual_calls = mock_session.execute.call_args_list
for i, actual_call in enumerate(actual_calls):
sql_text = str(actual_call[0][0])
normalized = " ".join(sql_text.split())
if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4)
# Verify it's a SELECT query that now includes file_id
sql_text = str(actual_call[0][0])
assert "SELECT id, file_id FROM workflow_draft_variables" in sql_text
assert "WHERE app_id = :app_id" in sql_text
assert "LIMIT :batch_size" in sql_text
assert "SELECT id, file_id FROM workflow_draft_variables" in normalized
assert "WHERE app_id = :app_id" in normalized
assert "LIMIT :batch_size" in normalized
else: # DELETE calls (odd indices: 1, 3)
# Verify it's a DELETE query
sql_text = str(actual_call[0][0])
assert "DELETE FROM workflow_draft_variables" in sql_text
assert "WHERE id IN :ids" in sql_text
assert "DELETE FROM workflow_draft_variables" in normalized
assert "WHERE id IN :ids" in normalized
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
def test_delete_draft_variables_batch_empty_result(self, mock_db, mock_offload_cleanup):
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_empty_result(self, mock_sf, mock_offload_cleanup):
"""Test deletion when no draft variables exist for the app."""
app_id = "nonexistent-app-id"
batch_size = 1000
# Mock database connection
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager
# Mock empty result
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_conn.execute.return_value = empty_result
mock_session.execute.return_value = empty_result
result = delete_draft_variables_batch(app_id, batch_size)
assert result == 0
assert mock_conn.execute.call_count == 1 # Only one select query
assert mock_session.execute.call_count == 1 # Only one select query
mock_offload_cleanup.assert_not_called() # No files to clean up
def test_delete_draft_variables_batch_invalid_batch_size(self):
@@ -147,22 +139,19 @@ class TestDeleteDraftVariablesBatch:
delete_draft_variables_batch(app_id, 0)
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
@patch("tasks.remove_app_and_related_data_task.session_factory")
@patch("tasks.remove_app_and_related_data_task.logger")
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db, mock_offload_cleanup):
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_sf, mock_offload_cleanup):
"""Test that batch deletion logs progress correctly."""
app_id = "test-app-id"
batch_size = 50
# Mock database
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager
# Mock one batch then empty
batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)]
@@ -183,7 +172,7 @@ class TestDeleteDraftVariablesBatch:
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_conn.execute.side_effect = [
mock_session.execute.side_effect = [
# Select query result
select_result,
# Delete query result
@@ -201,7 +190,7 @@ class TestDeleteDraftVariablesBatch:
# Verify offload cleanup was called with file_ids
if batch_file_ids:
mock_offload_cleanup.assert_called_once_with(mock_conn, batch_file_ids)
mock_offload_cleanup.assert_called_once_with(mock_session, batch_file_ids)
# Verify logging calls
assert mock_logging.info.call_count == 2
@@ -261,19 +250,19 @@ class TestDeleteDraftVariableOffloadData:
actual_calls = mock_conn.execute.call_args_list
# First call should be the SELECT query
select_call_sql = str(actual_calls[0][0][0])
select_call_sql = " ".join(str(actual_calls[0][0][0]).split())
assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql
assert "FROM workflow_draft_variable_files wdvf" in select_call_sql
assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql
assert "WHERE wdvf.id IN :file_ids" in select_call_sql
# Second call should be DELETE upload_files
delete_upload_call_sql = str(actual_calls[1][0][0])
delete_upload_call_sql = " ".join(str(actual_calls[1][0][0]).split())
assert "DELETE FROM upload_files" in delete_upload_call_sql
assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql
# Third call should be DELETE workflow_draft_variable_files
delete_variable_files_call_sql = str(actual_calls[2][0][0])
delete_variable_files_call_sql = " ".join(str(actual_calls[2][0][0]).split())
assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql
assert "WHERE id IN :file_ids" in delete_variable_files_call_sql