mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-14 20:41:21 +08:00
refactor: use session factory instead of call db.session directly (#31198)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -4,8 +4,8 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, UploadFile
|
||||
@@ -16,26 +16,23 @@ from tasks.remove_app_and_related_data_task import _delete_draft_variables, dele
|
||||
@pytest.fixture
|
||||
def app_and_tenant(flask_req_ctx):
|
||||
tenant_id = uuid.uuid4()
|
||||
tenant = Tenant(
|
||||
id=tenant_id,
|
||||
name="test_tenant",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
with session_factory.create_session() as session:
|
||||
tenant = Tenant(name="test_tenant")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant_id, # Now tenant.id will have a value
|
||||
name=f"Test App for tenant {tenant.id}",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app)
|
||||
db.session.flush()
|
||||
yield (tenant, app)
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=f"Test App for tenant {tenant.id}",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
|
||||
# Cleanup with proper error handling
|
||||
db.session.delete(app)
|
||||
db.session.delete(tenant)
|
||||
# return detached objects (ids will be used by tests)
|
||||
return (tenant, app)
|
||||
|
||||
|
||||
class TestDeleteDraftVariablesIntegration:
|
||||
@@ -44,334 +41,285 @@ class TestDeleteDraftVariablesIntegration:
|
||||
"""Create test data with apps and draft variables."""
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
# Create a second app for testing
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app2)
|
||||
db.session.commit()
|
||||
|
||||
# Create draft variables for both apps
|
||||
variables_app1 = []
|
||||
variables_app2 = []
|
||||
|
||||
for i in range(5):
|
||||
var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
with session_factory.create_session() as session:
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(var1)
|
||||
variables_app1.append(var1)
|
||||
session.add(app2)
|
||||
session.flush()
|
||||
|
||||
var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(var2)
|
||||
variables_app2.append(var2)
|
||||
variables_app1 = []
|
||||
variables_app2 = []
|
||||
for i in range(5):
|
||||
var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(var1)
|
||||
variables_app1.append(var1)
|
||||
|
||||
# Commit all the variables to the database
|
||||
db.session.commit()
|
||||
var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(var2)
|
||||
variables_app2.append(var2)
|
||||
session.commit()
|
||||
|
||||
app2_id = app2.id
|
||||
|
||||
yield {
|
||||
"app1": app,
|
||||
"app2": app2,
|
||||
"app2": App(id=app2_id), # dummy with id to avoid open session
|
||||
"tenant": tenant,
|
||||
"variables_app1": variables_app1,
|
||||
"variables_app2": variables_app2,
|
||||
}
|
||||
|
||||
# Cleanup - refresh session and check if objects still exist
|
||||
db.session.rollback() # Clear any pending changes
|
||||
|
||||
# Clean up remaining variables
|
||||
cleanup_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id.in_([app.id, app2.id]),
|
||||
with session_factory.create_session() as session:
|
||||
cleanup_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.app_id.in_([app.id, app2_id]))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(cleanup_query)
|
||||
|
||||
# Clean up app2
|
||||
app2_obj = db.session.get(App, app2.id)
|
||||
if app2_obj:
|
||||
db.session.delete(app2_obj)
|
||||
|
||||
db.session.commit()
|
||||
session.execute(cleanup_query)
|
||||
app2_obj = session.get(App, app2_id)
|
||||
if app2_obj:
|
||||
session.delete(app2_obj)
|
||||
session.commit()
|
||||
|
||||
def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data):
|
||||
"""Test that batch deletion only removes variables for the specified app."""
|
||||
data = setup_test_data
|
||||
app1_id = data["app1"].id
|
||||
app2_id = data["app2"].id
|
||||
|
||||
# Verify initial state
|
||||
app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
assert app1_vars_before == 5
|
||||
assert app2_vars_before == 5
|
||||
|
||||
# Delete app1 variables
|
||||
deleted_count = delete_draft_variables_batch(app1_id, batch_size=10)
|
||||
|
||||
# Verify results
|
||||
assert deleted_count == 5
|
||||
|
||||
app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
|
||||
assert app1_vars_after == 0 # All app1 variables deleted
|
||||
assert app2_vars_after == 5 # App2 variables unchanged
|
||||
with session_factory.create_session() as session:
|
||||
app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
assert app1_vars_after == 0
|
||||
assert app2_vars_after == 5
|
||||
|
||||
def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data):
|
||||
"""Test batch deletion with small batch size processes all records."""
|
||||
data = setup_test_data
|
||||
app1_id = data["app1"].id
|
||||
|
||||
# Use small batch size to force multiple batches
|
||||
deleted_count = delete_draft_variables_batch(app1_id, batch_size=2)
|
||||
|
||||
assert deleted_count == 5
|
||||
|
||||
# Verify all variables are deleted
|
||||
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
assert remaining_vars == 0
|
||||
|
||||
def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
|
||||
"""Test that deleting variables for nonexistent app returns 0."""
|
||||
nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format
|
||||
|
||||
nonexistent_app_id = str(uuid.uuid4())
|
||||
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100)
|
||||
|
||||
assert deleted_count == 0
|
||||
|
||||
def test_delete_draft_variables_wrapper_function(self, setup_test_data):
|
||||
"""Test that _delete_draft_variables wrapper function works correctly."""
|
||||
data = setup_test_data
|
||||
app1_id = data["app1"].id
|
||||
|
||||
# Verify initial state
|
||||
vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
assert vars_before == 5
|
||||
|
||||
# Call wrapper function
|
||||
deleted_count = _delete_draft_variables(app1_id)
|
||||
|
||||
# Verify results
|
||||
assert deleted_count == 5
|
||||
|
||||
vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
assert vars_after == 0
|
||||
|
||||
def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
|
||||
"""Test batch deletion with larger dataset to verify batching logic."""
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
# Create many draft variables
|
||||
variables = []
|
||||
for i in range(25):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(var)
|
||||
variables.append(var)
|
||||
variable_ids = [i.id for i in variables]
|
||||
|
||||
# Commit the variables to the database
|
||||
db.session.commit()
|
||||
variable_ids: list[str] = []
|
||||
with session_factory.create_session() as session:
|
||||
variables = []
|
||||
for i in range(25):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(var)
|
||||
variables.append(var)
|
||||
session.commit()
|
||||
variable_ids = [v.id for v in variables]
|
||||
|
||||
try:
|
||||
# Use small batch size to force multiple batches
|
||||
deleted_count = delete_draft_variables_batch(app.id, batch_size=8)
|
||||
|
||||
assert deleted_count == 25
|
||||
|
||||
# Verify all variables are deleted
|
||||
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
|
||||
assert remaining_vars == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
|
||||
assert remaining == 0
|
||||
finally:
|
||||
query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.id.in_(variable_ids),
|
||||
with session_factory.create_session() as session:
|
||||
query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.id.in_(variable_ids))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(query)
|
||||
session.execute(query)
|
||||
session.commit()
|
||||
|
||||
|
||||
class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
"""Integration tests for draft variable deletion with Offload data."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_offload_test_data(self, app_and_tenant):
|
||||
"""Create test data with draft variables that have associated Offload files."""
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
# Create UploadFile records
|
||||
from core.variables.types import SegmentType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
db.session.add(upload_file1)
|
||||
db.session.add(upload_file2)
|
||||
db.session.flush()
|
||||
with session_factory.create_session() as session:
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
session.add(upload_file1)
|
||||
session.add(upload_file2)
|
||||
session.flush()
|
||||
|
||||
# Create WorkflowDraftVariableFile records
|
||||
from core.variables.types import SegmentType
|
||||
var_file1 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file1.id,
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.STRING,
|
||||
)
|
||||
var_file2 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file2.id,
|
||||
size=2048,
|
||||
length=20,
|
||||
value_type=SegmentType.OBJECT,
|
||||
)
|
||||
session.add(var_file1)
|
||||
session.add(var_file2)
|
||||
session.flush()
|
||||
|
||||
var_file1 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file1.id,
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.STRING,
|
||||
)
|
||||
var_file2 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file2.id,
|
||||
size=2048,
|
||||
length=20,
|
||||
value_type=SegmentType.OBJECT,
|
||||
)
|
||||
db.session.add(var_file1)
|
||||
db.session.add(var_file2)
|
||||
db.session.flush()
|
||||
draft_var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_1",
|
||||
name="large_var_1",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file1.id,
|
||||
)
|
||||
draft_var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_2",
|
||||
name="large_var_2",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file2.id,
|
||||
)
|
||||
draft_var3 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_3",
|
||||
name="regular_var",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(draft_var1)
|
||||
session.add(draft_var2)
|
||||
session.add(draft_var3)
|
||||
session.commit()
|
||||
|
||||
# Create WorkflowDraftVariable records with file associations
|
||||
draft_var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_1",
|
||||
name="large_var_1",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file1.id,
|
||||
)
|
||||
draft_var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_2",
|
||||
name="large_var_2",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file2.id,
|
||||
)
|
||||
# Create a regular variable without Offload data
|
||||
draft_var3 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_3",
|
||||
name="regular_var",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
data = {
|
||||
"app": app,
|
||||
"tenant": tenant,
|
||||
"upload_files": [upload_file1, upload_file2],
|
||||
"variable_files": [var_file1, var_file2],
|
||||
"draft_variables": [draft_var1, draft_var2, draft_var3],
|
||||
}
|
||||
|
||||
db.session.add(draft_var1)
|
||||
db.session.add(draft_var2)
|
||||
db.session.add(draft_var3)
|
||||
db.session.commit()
|
||||
yield data
|
||||
|
||||
yield {
|
||||
"app": app,
|
||||
"tenant": tenant,
|
||||
"upload_files": [upload_file1, upload_file2],
|
||||
"variable_files": [var_file1, var_file2],
|
||||
"draft_variables": [draft_var1, draft_var2, draft_var3],
|
||||
}
|
||||
|
||||
# Cleanup
|
||||
db.session.rollback()
|
||||
|
||||
# Clean up any remaining records
|
||||
for table, ids in [
|
||||
(WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
|
||||
(WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
|
||||
(UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
|
||||
]:
|
||||
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
|
||||
db.session.execute(cleanup_query)
|
||||
|
||||
db.session.commit()
|
||||
with session_factory.create_session() as session:
|
||||
session.rollback()
|
||||
for table, ids in [
|
||||
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
|
||||
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
|
||||
(UploadFile, [uf.id for uf in data["upload_files"]]),
|
||||
]:
|
||||
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
|
||||
session.execute(cleanup_query)
|
||||
session.commit()
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that deleting draft variables also cleans up associated Offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Mock storage deletion to succeed
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Verify initial state
|
||||
draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_before = db.session.query(UploadFile).count()
|
||||
|
||||
assert draft_vars_before == 3 # 2 with files + 1 regular
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_before = session.query(UploadFile).count()
|
||||
assert draft_vars_before == 3
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
|
||||
# Delete draft variables
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
|
||||
# Verify results
|
||||
assert deleted_count == 3
|
||||
|
||||
# Check that all draft variables are deleted
|
||||
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert draft_vars_after == 0
|
||||
|
||||
# Check that associated Offload data is cleaned up
|
||||
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = db.session.query(UploadFile).count()
|
||||
with session_factory.create_session() as session:
|
||||
var_files_after = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = session.query(UploadFile).count()
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
assert var_files_after == 0 # All variable files should be deleted
|
||||
assert upload_files_after == 0 # All upload files should be deleted
|
||||
|
||||
# Verify storage deletion was called for both files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
|
||||
assert "test/file1.json" in storage_keys_deleted
|
||||
@@ -379,92 +327,71 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that database cleanup continues even when storage deletion fails."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Mock storage deletion to fail for first file, succeed for second
|
||||
mock_storage.delete.side_effect = [Exception("Storage error"), None]
|
||||
|
||||
# Delete draft variables
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
|
||||
# Verify that all draft variables are still deleted
|
||||
assert deleted_count == 3
|
||||
|
||||
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert draft_vars_after == 0
|
||||
|
||||
# Database cleanup should still succeed even with storage errors
|
||||
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = db.session.query(UploadFile).count()
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
var_files_after = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = session.query(UploadFile).count()
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
# Verify storage deletion was attempted for both files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
|
||||
"""Test deletion with mix of variables with and without Offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Create additional app with only regular variables (no offload data)
|
||||
tenant = data["tenant"]
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app2)
|
||||
db.session.flush()
|
||||
|
||||
# Add regular variables to app2
|
||||
regular_vars = []
|
||||
for i in range(3):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
with session_factory.create_session() as session:
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(var)
|
||||
regular_vars.append(var)
|
||||
db.session.commit()
|
||||
session.add(app2)
|
||||
session.flush()
|
||||
|
||||
for i in range(3):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(var)
|
||||
session.commit()
|
||||
|
||||
try:
|
||||
# Mock storage deletion
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Delete variables for app2 (no offload data)
|
||||
deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
|
||||
assert deleted_count_app2 == 3
|
||||
|
||||
# Verify storage wasn't called for app2 (no offload files)
|
||||
mock_storage.delete.assert_not_called()
|
||||
|
||||
# Delete variables for original app (with offload data)
|
||||
deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
assert deleted_count_app1 == 3
|
||||
|
||||
# Now storage should be called for the offload files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
finally:
|
||||
# Cleanup app2 and its variables
|
||||
cleanup_vars_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.app_id == app2.id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(cleanup_vars_query)
|
||||
|
||||
app2_obj = db.session.get(App, app2.id)
|
||||
if app2_obj:
|
||||
db.session.delete(app2_obj)
|
||||
db.session.commit()
|
||||
with session_factory.create_session() as session:
|
||||
cleanup_vars_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.app_id == app2.id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.execute(cleanup_vars_query)
|
||||
app2_obj = session.get(App, app2.id)
|
||||
if app2_obj:
|
||||
session.delete(app2_obj)
|
||||
session.commit()
|
||||
|
||||
@@ -39,23 +39,22 @@ class TestCleanDatasetTask:
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers):
|
||||
"""Clean up database before each test to ensure isolation."""
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
# Clear all test data
|
||||
db.session.query(DatasetMetadataBinding).delete()
|
||||
db.session.query(DatasetMetadata).delete()
|
||||
db.session.query(AppDatasetJoin).delete()
|
||||
db.session.query(DatasetQuery).delete()
|
||||
db.session.query(DatasetProcessRule).delete()
|
||||
db.session.query(DocumentSegment).delete()
|
||||
db.session.query(Document).delete()
|
||||
db.session.query(Dataset).delete()
|
||||
db.session.query(UploadFile).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.commit()
|
||||
# Clear all test data using the provided session fixture
|
||||
db_session_with_containers.query(DatasetMetadataBinding).delete()
|
||||
db_session_with_containers.query(DatasetMetadata).delete()
|
||||
db_session_with_containers.query(AppDatasetJoin).delete()
|
||||
db_session_with_containers.query(DatasetQuery).delete()
|
||||
db_session_with_containers.query(DatasetProcessRule).delete()
|
||||
db_session_with_containers.query(DocumentSegment).delete()
|
||||
db_session_with_containers.query(Document).delete()
|
||||
db_session_with_containers.query(Dataset).delete()
|
||||
db_session_with_containers.query(UploadFile).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Clear Redis cache
|
||||
redis_client.flushdb()
|
||||
@@ -103,10 +102,8 @@ class TestCleanDatasetTask:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
@@ -115,8 +112,8 @@ class TestCleanDatasetTask:
|
||||
status="active",
|
||||
)
|
||||
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account relationship
|
||||
tenant_account_join = TenantAccountJoin(
|
||||
@@ -125,8 +122,8 @@ class TestCleanDatasetTask:
|
||||
role=TenantAccountRole.OWNER,
|
||||
)
|
||||
|
||||
db.session.add(tenant_account_join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant_account_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return account, tenant
|
||||
|
||||
@@ -155,10 +152,8 @@ class TestCleanDatasetTask:
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -194,10 +189,8 @@ class TestCleanDatasetTask:
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return document
|
||||
|
||||
@@ -232,10 +225,8 @@ class TestCleanDatasetTask:
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return segment
|
||||
|
||||
@@ -267,10 +258,8 @@ class TestCleanDatasetTask:
|
||||
used=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(upload_file)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
@@ -302,31 +291,29 @@ class TestCleanDatasetTask:
|
||||
)
|
||||
|
||||
# Verify results
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Check that dataset-related data was cleaned up
|
||||
documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(documents) == 0
|
||||
|
||||
segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(segments) == 0
|
||||
|
||||
# Check that metadata and bindings were cleaned up
|
||||
metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(metadata) == 0
|
||||
|
||||
bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(bindings) == 0
|
||||
|
||||
# Check that process rules and queries were cleaned up
|
||||
process_rules = db.session.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
|
||||
process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(process_rules) == 0
|
||||
|
||||
queries = db.session.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
|
||||
queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(queries) == 0
|
||||
|
||||
# Check that app dataset joins were cleaned up
|
||||
app_joins = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
|
||||
app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(app_joins) == 0
|
||||
|
||||
# Verify index processor was called
|
||||
@@ -378,9 +365,7 @@ class TestCleanDatasetTask:
|
||||
import json
|
||||
|
||||
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset metadata and bindings
|
||||
metadata = DatasetMetadata(
|
||||
@@ -403,11 +388,9 @@ class TestCleanDatasetTask:
|
||||
binding.id = str(uuid.uuid4())
|
||||
binding.created_at = datetime.now()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(metadata)
|
||||
db.session.add(binding)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(metadata)
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the task
|
||||
clean_dataset_task(
|
||||
@@ -421,22 +404,24 @@ class TestCleanDatasetTask:
|
||||
|
||||
# Verify results
|
||||
# Check that all documents were deleted
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that all segments were deleted
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Check that all upload files were deleted
|
||||
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
|
||||
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
|
||||
assert len(remaining_files) == 0
|
||||
|
||||
# Check that metadata and bindings were cleaned up
|
||||
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_metadata) == 0
|
||||
|
||||
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_bindings = (
|
||||
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
)
|
||||
assert len(remaining_bindings) == 0
|
||||
|
||||
# Verify index processor was called
|
||||
@@ -489,12 +474,13 @@ class TestCleanDatasetTask:
|
||||
mock_index_processor.clean.assert_called_once()
|
||||
|
||||
# Check that all data was cleaned up
|
||||
from extensions.ext_database import db
|
||||
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = (
|
||||
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
)
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Recreate data for next test case
|
||||
@@ -540,14 +526,13 @@ class TestCleanDatasetTask:
|
||||
)
|
||||
|
||||
# Verify results - even with vector cleanup failure, documents and segments should be deleted
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Check that documents were still deleted despite vector cleanup failure
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that segments were still deleted despite vector cleanup failure
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Verify that index processor was called and failed
|
||||
@@ -608,10 +593,8 @@ class TestCleanDatasetTask:
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock the get_image_upload_file_ids function to return our image file IDs
|
||||
with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids:
|
||||
@@ -629,16 +612,18 @@ class TestCleanDatasetTask:
|
||||
|
||||
# Verify results
|
||||
# Check that all documents were deleted
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that all segments were deleted
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Check that all image files were deleted from database
|
||||
image_file_ids = [f.id for f in image_files]
|
||||
remaining_image_files = db.session.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
|
||||
remaining_image_files = (
|
||||
db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
|
||||
)
|
||||
assert len(remaining_image_files) == 0
|
||||
|
||||
# Verify that storage.delete was called for each image file
|
||||
@@ -745,22 +730,24 @@ class TestCleanDatasetTask:
|
||||
|
||||
# Verify results
|
||||
# Check that all documents were deleted
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that all segments were deleted
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Check that all upload files were deleted
|
||||
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
|
||||
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
|
||||
assert len(remaining_files) == 0
|
||||
|
||||
# Check that all metadata and bindings were deleted
|
||||
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_metadata) == 0
|
||||
|
||||
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_bindings = (
|
||||
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
)
|
||||
assert len(remaining_bindings) == 0
|
||||
|
||||
# Verify performance expectations
|
||||
@@ -808,9 +795,7 @@ class TestCleanDatasetTask:
|
||||
import json
|
||||
|
||||
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock storage to raise exceptions
|
||||
mock_storage = mock_external_service_dependencies["storage"]
|
||||
@@ -827,18 +812,13 @@ class TestCleanDatasetTask:
|
||||
)
|
||||
|
||||
# Verify results
|
||||
# Check that documents were still deleted despite storage failure
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that segments were still deleted despite storage failure
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
# Note: When storage operations fail, database deletions may be rolled back by implementation.
|
||||
# This test focuses on ensuring the task handles the exception and continues execution/logging.
|
||||
|
||||
# Check that upload file was still deleted from database despite storage failure
|
||||
# Note: When storage operations fail, the upload file may not be deleted
|
||||
# This demonstrates that the cleanup process continues even with storage errors
|
||||
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file.id).all()
|
||||
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all()
|
||||
# The upload file should still be deleted from the database even if storage cleanup fails
|
||||
# However, this depends on the specific implementation of clean_dataset_task
|
||||
if len(remaining_files) > 0:
|
||||
@@ -890,10 +870,8 @@ class TestCleanDatasetTask:
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create document with special characters in name
|
||||
special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~"
|
||||
@@ -912,8 +890,8 @@ class TestCleanDatasetTask:
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create segment with special characters and very long content
|
||||
long_content = "Very long content " * 100 # Long content within reasonable limits
|
||||
@@ -934,8 +912,8 @@ class TestCleanDatasetTask:
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create upload file with special characters in name
|
||||
special_filename = f"test_file_{special_content}.txt"
|
||||
@@ -952,14 +930,14 @@ class TestCleanDatasetTask:
|
||||
created_at=datetime.now(),
|
||||
used=False,
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(upload_file)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Update document with file reference
|
||||
import json
|
||||
|
||||
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Save upload file ID for verification
|
||||
upload_file_id = upload_file.id
|
||||
@@ -975,8 +953,8 @@ class TestCleanDatasetTask:
|
||||
special_metadata.id = str(uuid.uuid4())
|
||||
special_metadata.created_at = datetime.now()
|
||||
|
||||
db.session.add(special_metadata)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(special_metadata)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the task
|
||||
clean_dataset_task(
|
||||
@@ -990,19 +968,19 @@ class TestCleanDatasetTask:
|
||||
|
||||
# Verify results
|
||||
# Check that all documents were deleted
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that all segments were deleted
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Check that all upload files were deleted
|
||||
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all()
|
||||
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all()
|
||||
assert len(remaining_files) == 0
|
||||
|
||||
# Check that all metadata was deleted
|
||||
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_metadata) == 0
|
||||
|
||||
# Verify that storage.delete was called
|
||||
|
||||
+17
-34
@@ -24,16 +24,15 @@ class TestCreateSegmentToIndexTask:
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers):
|
||||
"""Clean up database and Redis before each test to ensure isolation."""
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Clear all test data
|
||||
db.session.query(DocumentSegment).delete()
|
||||
db.session.query(Document).delete()
|
||||
db.session.query(Dataset).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.commit()
|
||||
# Clear all test data using fixture session
|
||||
db_session_with_containers.query(DocumentSegment).delete()
|
||||
db_session_with_containers.query(Document).delete()
|
||||
db_session_with_containers.query(Dataset).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Clear Redis cache
|
||||
redis_client.flushdb()
|
||||
@@ -73,10 +72,8 @@ class TestCreateSegmentToIndexTask:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
@@ -84,8 +81,8 @@ class TestCreateSegmentToIndexTask:
|
||||
status="normal",
|
||||
plan="basic",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join with owner role
|
||||
join = TenantAccountJoin(
|
||||
@@ -94,8 +91,8 @@ class TestCreateSegmentToIndexTask:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
@@ -746,20 +743,9 @@ class TestCreateSegmentToIndexTask:
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
)
|
||||
|
||||
# Mock global database session to simulate transaction issues
|
||||
from extensions.ext_database import db
|
||||
|
||||
original_commit = db.session.commit
|
||||
commit_called = False
|
||||
|
||||
def mock_commit():
|
||||
nonlocal commit_called
|
||||
if not commit_called:
|
||||
commit_called = True
|
||||
raise Exception("Database commit failed")
|
||||
return original_commit()
|
||||
|
||||
db.session.commit = mock_commit
|
||||
# Simulate an error during indexing to trigger rollback path
|
||||
mock_processor = mock_external_service_dependencies["index_processor"]
|
||||
mock_processor.load.side_effect = Exception("Simulated indexing error")
|
||||
|
||||
# Act: Execute the task
|
||||
create_segment_to_index_task(segment.id)
|
||||
@@ -771,9 +757,6 @@ class TestCreateSegmentToIndexTask:
|
||||
assert segment.disabled_at is not None
|
||||
assert segment.error is not None
|
||||
|
||||
# Restore original commit method
|
||||
db.session.commit = original_commit
|
||||
|
||||
def test_create_segment_to_index_metadata_validation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
+14
-25
@@ -70,11 +70,9 @@ class TestDisableSegmentsFromIndexTask:
|
||||
tenant.created_at = fake.date_time_this_year()
|
||||
tenant.updated_at = tenant.created_at
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(tenant)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set the current tenant for the account
|
||||
account.current_tenant = tenant
|
||||
@@ -110,10 +108,8 @@ class TestDisableSegmentsFromIndexTask:
|
||||
built_in_field_enabled=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -158,10 +154,8 @@ class TestDisableSegmentsFromIndexTask:
|
||||
document.archived = False
|
||||
document.doc_form = "text_model" # Use text_model form for testing
|
||||
document.doc_language = "en"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return document
|
||||
|
||||
@@ -211,11 +205,9 @@ class TestDisableSegmentsFromIndexTask:
|
||||
|
||||
segments.append(segment)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
for segment in segments:
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return segments
|
||||
|
||||
@@ -645,15 +637,12 @@ class TestDisableSegmentsFromIndexTask:
|
||||
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
|
||||
mock_redis.delete.return_value = True
|
||||
|
||||
# Mock db.session.close to verify it's called
|
||||
with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close:
|
||||
# Act
|
||||
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
|
||||
# Act
|
||||
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
|
||||
|
||||
# Assert
|
||||
assert result is None # Task should complete without returning a value
|
||||
# Verify session was closed
|
||||
mock_close.assert_called()
|
||||
# Assert
|
||||
assert result is None # Task should complete without returning a value
|
||||
# Session lifecycle is managed by context manager; no explicit close assertion
|
||||
|
||||
def test_disable_segments_empty_segment_ids(self, db_session_with_containers):
|
||||
"""
|
||||
|
||||
@@ -6,7 +6,6 @@ from faker import Faker
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
from tasks.document_indexing_task import (
|
||||
@@ -75,15 +74,15 @@ class TestDocumentIndexingTasks:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@@ -92,8 +91,8 @@ class TestDocumentIndexingTasks:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
@@ -105,8 +104,8 @@ class TestDocumentIndexingTasks:
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create documents
|
||||
documents = []
|
||||
@@ -124,13 +123,13 @@ class TestDocumentIndexingTasks:
|
||||
indexing_status="waiting",
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh dataset to ensure it's properly loaded
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
return dataset, documents
|
||||
|
||||
@@ -157,15 +156,15 @@ class TestDocumentIndexingTasks:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@@ -174,8 +173,8 @@ class TestDocumentIndexingTasks:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
@@ -187,8 +186,8 @@ class TestDocumentIndexingTasks:
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create documents
|
||||
documents = []
|
||||
@@ -206,10 +205,10 @@ class TestDocumentIndexingTasks:
|
||||
indexing_status="waiting",
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Configure billing features
|
||||
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
|
||||
@@ -219,7 +218,7 @@ class TestDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["features"].vector_space.size = 50
|
||||
|
||||
# Refresh dataset to ensure it's properly loaded
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
return dataset, documents
|
||||
|
||||
@@ -242,6 +241,9 @@ class TestDocumentIndexingTasks:
|
||||
# Act: Execute the task
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify indexing runner was called correctly
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
@@ -250,7 +252,7 @@ class TestDocumentIndexingTasks:
|
||||
# Verify documents were updated to parsing status
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -310,6 +312,9 @@ class TestDocumentIndexingTasks:
|
||||
# Act: Execute the task with mixed document IDs
|
||||
_document_indexing(dataset.id, all_document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify only existing documents were processed
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
@@ -317,7 +322,7 @@ class TestDocumentIndexingTasks:
|
||||
# Verify only existing documents were updated
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in existing_document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -353,6 +358,9 @@ class TestDocumentIndexingTasks:
|
||||
# Act: Execute the task
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify exception was handled gracefully
|
||||
# The task should complete without raising exceptions
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
@@ -361,7 +369,7 @@ class TestDocumentIndexingTasks:
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
# Re-query documents from database since _document_indexing close the session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -400,7 +408,7 @@ class TestDocumentIndexingTasks:
|
||||
indexing_status="completed", # Already completed
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(doc1)
|
||||
db_session_with_containers.add(doc1)
|
||||
extra_documents.append(doc1)
|
||||
|
||||
# Document with disabled status
|
||||
@@ -417,10 +425,10 @@ class TestDocumentIndexingTasks:
|
||||
indexing_status="waiting",
|
||||
enabled=False, # Disabled
|
||||
)
|
||||
db.session.add(doc2)
|
||||
db_session_with_containers.add(doc2)
|
||||
extra_documents.append(doc2)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
all_documents = base_documents + extra_documents
|
||||
document_ids = [doc.id for doc in all_documents]
|
||||
@@ -428,6 +436,9 @@ class TestDocumentIndexingTasks:
|
||||
# Act: Execute the task with mixed document states
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify processing
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
@@ -435,7 +446,7 @@ class TestDocumentIndexingTasks:
|
||||
# Verify all documents were updated to parsing status
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -482,20 +493,23 @@ class TestDocumentIndexingTasks:
|
||||
indexing_status="waiting",
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
extra_documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
all_documents = documents + extra_documents
|
||||
document_ids = [doc.id for doc in all_documents]
|
||||
|
||||
# Act: Execute the task with too many documents for sandbox plan
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify error handling
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "error"
|
||||
assert updated_document.error is not None
|
||||
assert "batch upload" in updated_document.error
|
||||
@@ -526,6 +540,9 @@ class TestDocumentIndexingTasks:
|
||||
# Act: Execute the task with billing disabled
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify successful processing
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
@@ -533,7 +550,7 @@ class TestDocumentIndexingTasks:
|
||||
# Verify documents were updated to parsing status
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -565,6 +582,9 @@ class TestDocumentIndexingTasks:
|
||||
# Act: Execute the task
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify exception was handled gracefully
|
||||
# The task should complete without raising exceptions
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
@@ -573,7 +593,7 @@ class TestDocumentIndexingTasks:
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -674,6 +694,9 @@ class TestDocumentIndexingTasks:
|
||||
# Act: Execute the wrapper function
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify core processing occurred (same as _document_indexing)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
@@ -681,7 +704,7 @@ class TestDocumentIndexingTasks:
|
||||
# Verify documents were updated (same as _document_indexing)
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -794,6 +817,9 @@ class TestDocumentIndexingTasks:
|
||||
# Act: Execute the wrapper function
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
# The function should not raise exceptions
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
@@ -802,7 +828,7 @@ class TestDocumentIndexingTasks:
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -865,6 +891,9 @@ class TestDocumentIndexingTasks:
|
||||
# Act: Execute the wrapper function for tenant1 only
|
||||
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify core processing occurred for tenant1
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
+66
-43
@@ -4,7 +4,6 @@ import pytest
|
||||
from faker import Faker
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.duplicate_document_indexing_task import (
|
||||
@@ -82,15 +81,15 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@@ -99,8 +98,8 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
@@ -112,8 +111,8 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create documents
|
||||
documents = []
|
||||
@@ -132,13 +131,13 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
enabled=True,
|
||||
doc_form="text_model",
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh dataset to ensure it's properly loaded
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
return dataset, documents
|
||||
|
||||
@@ -183,14 +182,14 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
indexing_at=fake.date_time_this_year(),
|
||||
created_by=dataset.created_by, # Add required field
|
||||
)
|
||||
db.session.add(segment)
|
||||
db_session_with_containers.add(segment)
|
||||
segments.append(segment)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh to ensure all relationships are loaded
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
|
||||
return dataset, documents, segments
|
||||
|
||||
@@ -217,15 +216,15 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@@ -234,8 +233,8 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
@@ -247,8 +246,8 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create documents
|
||||
documents = []
|
||||
@@ -267,10 +266,10 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
enabled=True,
|
||||
doc_form="text_model",
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Configure billing features
|
||||
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
|
||||
@@ -280,7 +279,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["features"].vector_space.size = 50
|
||||
|
||||
# Refresh dataset to ensure it's properly loaded
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
return dataset, documents
|
||||
|
||||
@@ -305,6 +304,9 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
# Act: Execute the task
|
||||
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify indexing runner was called correctly
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
@@ -313,7 +315,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
# Verify documents were updated to parsing status
|
||||
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -340,23 +342,32 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
segment_ids = [seg.id for seg in segments]
|
||||
|
||||
# Act: Execute the task
|
||||
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify segment cleanup
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify segment cleanup
|
||||
# Verify index processor clean was called for each document with segments
|
||||
assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents)
|
||||
|
||||
# Verify segments were deleted from database
|
||||
# Re-query segments from database since _duplicate_document_indexing_task uses a different session
|
||||
for segment in segments:
|
||||
deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
|
||||
# Re-query segments from database using captured IDs to avoid stale ORM instances
|
||||
for seg_id in segment_ids:
|
||||
deleted_segment = (
|
||||
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first()
|
||||
)
|
||||
assert deleted_segment is None
|
||||
|
||||
# Verify documents were updated to parsing status
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -415,6 +426,9 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
# Act: Execute the task with mixed document IDs
|
||||
_duplicate_document_indexing_task(dataset.id, all_document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify only existing documents were processed
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
@@ -422,7 +436,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
# Verify only existing documents were updated
|
||||
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
|
||||
for doc_id in existing_document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -458,6 +472,9 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
# Act: Execute the task
|
||||
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify exception was handled gracefully
|
||||
# The task should complete without raising exceptions
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
@@ -466,7 +483,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
# Re-query documents from database since _duplicate_document_indexing_task close the session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
@@ -508,20 +525,23 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
enabled=True,
|
||||
doc_form="text_model",
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
extra_documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
all_documents = documents + extra_documents
|
||||
document_ids = [doc.id for doc in all_documents]
|
||||
|
||||
# Act: Execute the task with too many documents for sandbox plan
|
||||
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify error handling
|
||||
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "error"
|
||||
assert updated_document.error is not None
|
||||
assert "batch upload" in updated_document.error.lower()
|
||||
@@ -557,10 +577,13 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
# Act: Execute the task with documents that will exceed vector space limit
|
||||
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify error handling
|
||||
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "error"
|
||||
assert updated_document.error is not None
|
||||
assert "limit" in updated_document.error.lower()
|
||||
@@ -620,11 +643,11 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Clear session cache to see database updates from task's session
|
||||
db.session.expire_all()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Verify documents were processed
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
|
||||
@@ -663,11 +686,11 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
mock_queue.delete_task_key.assert_called_once()
|
||||
|
||||
# Clear session cache to see database updates from task's session
|
||||
db.session.expire_all()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Verify documents were processed
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
|
||||
@@ -707,11 +730,11 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
mock_queue.delete_task_key.assert_called_once()
|
||||
|
||||
# Clear session cache to see database updates from task's session
|
||||
db.session.expire_all()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Verify documents were processed
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
|
||||
|
||||
@@ -49,10 +49,14 @@ def pipeline_id():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session with query capabilities."""
|
||||
with patch("tasks.clean_dataset_task.db") as mock_db:
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.clean_dataset_task.session_factory") as mock_sf:
|
||||
mock_session = MagicMock()
|
||||
mock_db.session = mock_session
|
||||
# context manager for create_session()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = mock_session
|
||||
cm.__exit__.return_value = None
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
# Setup query chain
|
||||
mock_query = MagicMock()
|
||||
@@ -66,7 +70,10 @@ def mock_db_session():
|
||||
# Setup execute for JOIN queries
|
||||
mock_session.execute.return_value.all.return_value = []
|
||||
|
||||
yield mock_db
|
||||
# Yield an object with a `.session` attribute to keep tests unchanged
|
||||
wrapper = MagicMock()
|
||||
wrapper.session = mock_session
|
||||
yield wrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -227,7 +234,9 @@ class TestBasicCleanup:
|
||||
|
||||
# Assert
|
||||
mock_db_session.session.delete.assert_any_call(mock_document)
|
||||
mock_db_session.session.delete.assert_any_call(mock_segment)
|
||||
# Segments are deleted in batch; verify a DELETE on document_segments was issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
mock_db_session.session.commit.assert_called_once()
|
||||
|
||||
def test_clean_dataset_task_deletes_related_records(
|
||||
@@ -413,7 +422,9 @@ class TestErrorHandling:
|
||||
|
||||
# Assert - documents and segments should still be deleted
|
||||
mock_db_session.session.delete.assert_any_call(mock_document)
|
||||
mock_db_session.session.delete.assert_any_call(mock_segment)
|
||||
# Segments are deleted in batch; verify a DELETE on document_segments was issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
mock_db_session.session.commit.assert_called_once()
|
||||
|
||||
def test_clean_dataset_task_storage_delete_failure_continues(
|
||||
@@ -461,7 +472,7 @@ class TestErrorHandling:
|
||||
[mock_segment], # segments
|
||||
]
|
||||
mock_get_image_upload_file_ids.return_value = [image_file_id]
|
||||
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
|
||||
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
|
||||
mock_storage.delete.side_effect = Exception("Storage service unavailable")
|
||||
|
||||
# Act
|
||||
@@ -476,8 +487,9 @@ class TestErrorHandling:
|
||||
|
||||
# Assert - storage delete was attempted for image file
|
||||
mock_storage.delete.assert_called_with(mock_upload_file.key)
|
||||
# Image file should still be deleted from database
|
||||
mock_db_session.session.delete.assert_any_call(mock_upload_file)
|
||||
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
|
||||
|
||||
def test_clean_dataset_task_database_error_rollback(
|
||||
self,
|
||||
@@ -691,8 +703,10 @@ class TestSegmentAttachmentCleanup:
|
||||
|
||||
# Assert
|
||||
mock_storage.delete.assert_called_with(mock_attachment_file.key)
|
||||
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
|
||||
mock_db_session.session.delete.assert_any_call(mock_binding)
|
||||
# Attachment file and binding are deleted in batch; verify DELETEs were issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
|
||||
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)
|
||||
|
||||
def test_clean_dataset_task_attachment_storage_failure(
|
||||
self,
|
||||
@@ -734,9 +748,10 @@ class TestSegmentAttachmentCleanup:
|
||||
|
||||
# Assert - storage delete was attempted
|
||||
mock_storage.delete.assert_called_once()
|
||||
# Records should still be deleted from database
|
||||
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
|
||||
mock_db_session.session.delete.assert_any_call(mock_binding)
|
||||
# Records are deleted in batch; verify DELETEs were issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
|
||||
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -784,7 +799,7 @@ class TestUploadFileCleanup:
|
||||
[mock_document], # documents
|
||||
[], # segments
|
||||
]
|
||||
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
|
||||
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
|
||||
|
||||
# Act
|
||||
clean_dataset_task(
|
||||
@@ -798,7 +813,9 @@ class TestUploadFileCleanup:
|
||||
|
||||
# Assert
|
||||
mock_storage.delete.assert_called_with(mock_upload_file.key)
|
||||
mock_db_session.session.delete.assert_any_call(mock_upload_file)
|
||||
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
|
||||
|
||||
def test_clean_dataset_task_handles_missing_upload_file(
|
||||
self,
|
||||
@@ -832,7 +849,7 @@ class TestUploadFileCleanup:
|
||||
[mock_document], # documents
|
||||
[], # segments
|
||||
]
|
||||
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db_session.session.query.return_value.where.return_value.all.return_value = []
|
||||
|
||||
# Act - should not raise exception
|
||||
clean_dataset_task(
|
||||
@@ -949,11 +966,11 @@ class TestImageFileCleanup:
|
||||
[mock_segment], # segments
|
||||
]
|
||||
|
||||
# Setup a mock query chain that returns files in sequence
|
||||
# Setup a mock query chain that returns files in batch (align with .in_().all())
|
||||
mock_query = MagicMock()
|
||||
mock_where = MagicMock()
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_where.first.side_effect = mock_image_files
|
||||
mock_where.all.return_value = mock_image_files
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
@@ -966,10 +983,10 @@ class TestImageFileCleanup:
|
||||
doc_form="paragraph_index",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert mock_storage.delete.call_count == 2
|
||||
mock_storage.delete.assert_any_call("images/image-1.jpg")
|
||||
mock_storage.delete.assert_any_call("images/image-2.jpg")
|
||||
# Assert - each expected image key was deleted at least once
|
||||
calls = [c.args[0] for c in mock_storage.delete.call_args_list]
|
||||
assert "images/image-1.jpg" in calls
|
||||
assert "images/image-2.jpg" in calls
|
||||
|
||||
def test_clean_dataset_task_handles_missing_image_file(
|
||||
self,
|
||||
@@ -1010,7 +1027,7 @@ class TestImageFileCleanup:
|
||||
]
|
||||
|
||||
# Image file not found
|
||||
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db_session.session.query.return_value.where.return_value.all.return_value = []
|
||||
|
||||
# Act - should not raise exception
|
||||
clean_dataset_task(
|
||||
@@ -1086,14 +1103,15 @@ class TestEdgeCases:
|
||||
doc_form="paragraph_index",
|
||||
)
|
||||
|
||||
# Assert - all documents and segments should be deleted
|
||||
# Assert - all documents and segments should be deleted (documents per-entity, segments in batch)
|
||||
delete_calls = mock_db_session.session.delete.call_args_list
|
||||
deleted_items = [call[0][0] for call in delete_calls]
|
||||
|
||||
for doc in mock_documents:
|
||||
assert doc in deleted_items
|
||||
for seg in mock_segments:
|
||||
assert seg in deleted_items
|
||||
# Verify a batch DELETE on document_segments occurred
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
|
||||
def test_clean_dataset_task_document_with_empty_data_source_info(
|
||||
self,
|
||||
|
||||
@@ -81,12 +81,25 @@ def mock_documents(document_ids, dataset_id):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session."""
|
||||
with patch("tasks.document_indexing_task.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
yield mock_session
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
# Ensure tests that expect session.close() to be called can observe it via the context manager
|
||||
session.close = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
|
||||
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -18,12 +18,18 @@ from tasks.delete_account_task import delete_account_task
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock the db.session used in delete_account_task."""
|
||||
with patch("tasks.delete_account_task.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
yield mock_session
|
||||
"""Mock session via session_factory.create_session()."""
|
||||
with patch("tasks.delete_account_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -109,13 +109,25 @@ def mock_document_segments(document_id):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session."""
|
||||
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_session.scalars.return_value = MagicMock()
|
||||
yield mock_session
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
# Ensure tests can observe session.close() via context manager teardown
|
||||
session.close = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
session.scalars.return_value = MagicMock()
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -251,8 +263,8 @@ class TestDocumentIndexingSyncTask:
|
||||
# Assert
|
||||
# Document status should remain unchanged
|
||||
assert mock_document.indexing_status == "completed"
|
||||
# No session operations should be performed beyond the initial query
|
||||
mock_db_session.close.assert_not_called()
|
||||
# Session should still be closed via context manager teardown
|
||||
assert mock_db_session.close.called
|
||||
|
||||
def test_successful_sync_when_page_updated(
|
||||
self,
|
||||
@@ -286,9 +298,9 @@ class TestDocumentIndexingSyncTask:
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_processor.clean.assert_called_once()
|
||||
|
||||
# Verify segments were deleted from database
|
||||
for segment in mock_document_segments:
|
||||
mock_db_session.delete.assert_any_call(segment)
|
||||
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
|
||||
# Verify indexing runner was called
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
|
||||
@@ -94,13 +94,25 @@ def mock_document_segments(document_ids):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session."""
|
||||
with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_session.scalars.return_value = MagicMock()
|
||||
yield mock_session
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
# Allow tests to observe session.close() via context manager teardown
|
||||
session.close = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
session.scalars.return_value = MagicMock()
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -200,8 +212,25 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
):
|
||||
"""Test successful duplicate document indexing flow."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
# Dataset via query.first()
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
# scalars() call sequence:
|
||||
# 1) documents list
|
||||
# 2..N) segments per document
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
# First call returns documents; subsequent calls return segments
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = mock_document_segments
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
|
||||
# Act
|
||||
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||
@@ -264,8 +293,21 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
):
|
||||
"""Test duplicate document indexing when billing limit is exceeded."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
|
||||
mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
# First scalars() -> documents; subsequent -> empty segments
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = []
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_features = mock_feature_service.get_features.return_value
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.billing.subscription.plan = CloudPlan.TEAM
|
||||
@@ -294,8 +336,20 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
):
|
||||
"""Test duplicate document indexing when IndexingRunner raises an error."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = []
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing error")
|
||||
|
||||
# Act
|
||||
@@ -318,8 +372,20 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
):
|
||||
"""Test duplicate document indexing when document is paused."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = []
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
|
||||
|
||||
# Act
|
||||
@@ -343,8 +409,20 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
):
|
||||
"""Test that duplicate document indexing cleans old segments."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = mock_document_segments
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
|
||||
# Act
|
||||
@@ -354,9 +432,9 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
# Verify clean was called for each document
|
||||
assert mock_processor.clean.call_count == len(mock_documents)
|
||||
|
||||
# Verify segments were deleted
|
||||
for segment in mock_document_segments:
|
||||
mock_db_session.delete.assert_any_call(segment)
|
||||
# Verify segments were deleted in batch (DELETE FROM document_segments)
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -11,21 +11,18 @@ from tasks.remove_app_and_related_data_task import (
|
||||
|
||||
class TestDeleteDraftVariablesBatch:
|
||||
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
|
||||
@patch("tasks.remove_app_and_related_data_task.db")
|
||||
def test_delete_draft_variables_batch_success(self, mock_db, mock_offload_cleanup):
|
||||
@patch("tasks.remove_app_and_related_data_task.session_factory")
|
||||
def test_delete_draft_variables_batch_success(self, mock_sf, mock_offload_cleanup):
|
||||
"""Test successful deletion of draft variables in batches."""
|
||||
app_id = "test-app-id"
|
||||
batch_size = 100
|
||||
|
||||
# Mock database connection and engine
|
||||
mock_conn = MagicMock()
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
# Properly mock the context manager
|
||||
# Mock session via session_factory
|
||||
mock_session = MagicMock()
|
||||
mock_context_manager = MagicMock()
|
||||
mock_context_manager.__enter__.return_value = mock_conn
|
||||
mock_context_manager.__enter__.return_value = mock_session
|
||||
mock_context_manager.__exit__.return_value = None
|
||||
mock_engine.begin.return_value = mock_context_manager
|
||||
mock_sf.create_session.return_value = mock_context_manager
|
||||
|
||||
# Mock two batches of results, then empty
|
||||
batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)]
|
||||
@@ -68,7 +65,7 @@ class TestDeleteDraftVariablesBatch:
|
||||
select_result3.__iter__.return_value = iter([])
|
||||
|
||||
# Configure side effects in the correct order
|
||||
mock_conn.execute.side_effect = [
|
||||
mock_session.execute.side_effect = [
|
||||
select_result1, # First SELECT
|
||||
delete_result1, # First DELETE
|
||||
select_result2, # Second SELECT
|
||||
@@ -86,54 +83,49 @@ class TestDeleteDraftVariablesBatch:
|
||||
assert result == 150
|
||||
|
||||
# Verify database calls
|
||||
assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes
|
||||
assert mock_session.execute.call_count == 5 # 3 selects + 2 deletes
|
||||
|
||||
# Verify offload cleanup was called for both batches with file_ids
|
||||
expected_offload_calls = [call(mock_conn, batch1_file_ids), call(mock_conn, batch2_file_ids)]
|
||||
expected_offload_calls = [call(mock_session, batch1_file_ids), call(mock_session, batch2_file_ids)]
|
||||
mock_offload_cleanup.assert_has_calls(expected_offload_calls)
|
||||
|
||||
# Simplified verification - check that the right number of calls were made
|
||||
# and that the SQL queries contain the expected patterns
|
||||
actual_calls = mock_conn.execute.call_args_list
|
||||
actual_calls = mock_session.execute.call_args_list
|
||||
for i, actual_call in enumerate(actual_calls):
|
||||
sql_text = str(actual_call[0][0])
|
||||
normalized = " ".join(sql_text.split())
|
||||
if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4)
|
||||
# Verify it's a SELECT query that now includes file_id
|
||||
sql_text = str(actual_call[0][0])
|
||||
assert "SELECT id, file_id FROM workflow_draft_variables" in sql_text
|
||||
assert "WHERE app_id = :app_id" in sql_text
|
||||
assert "LIMIT :batch_size" in sql_text
|
||||
assert "SELECT id, file_id FROM workflow_draft_variables" in normalized
|
||||
assert "WHERE app_id = :app_id" in normalized
|
||||
assert "LIMIT :batch_size" in normalized
|
||||
else: # DELETE calls (odd indices: 1, 3)
|
||||
# Verify it's a DELETE query
|
||||
sql_text = str(actual_call[0][0])
|
||||
assert "DELETE FROM workflow_draft_variables" in sql_text
|
||||
assert "WHERE id IN :ids" in sql_text
|
||||
assert "DELETE FROM workflow_draft_variables" in normalized
|
||||
assert "WHERE id IN :ids" in normalized
|
||||
|
||||
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
|
||||
@patch("tasks.remove_app_and_related_data_task.db")
|
||||
def test_delete_draft_variables_batch_empty_result(self, mock_db, mock_offload_cleanup):
|
||||
@patch("tasks.remove_app_and_related_data_task.session_factory")
|
||||
def test_delete_draft_variables_batch_empty_result(self, mock_sf, mock_offload_cleanup):
|
||||
"""Test deletion when no draft variables exist for the app."""
|
||||
app_id = "nonexistent-app-id"
|
||||
batch_size = 1000
|
||||
|
||||
# Mock database connection
|
||||
mock_conn = MagicMock()
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
# Properly mock the context manager
|
||||
# Mock session via session_factory
|
||||
mock_session = MagicMock()
|
||||
mock_context_manager = MagicMock()
|
||||
mock_context_manager.__enter__.return_value = mock_conn
|
||||
mock_context_manager.__enter__.return_value = mock_session
|
||||
mock_context_manager.__exit__.return_value = None
|
||||
mock_engine.begin.return_value = mock_context_manager
|
||||
mock_sf.create_session.return_value = mock_context_manager
|
||||
|
||||
# Mock empty result
|
||||
empty_result = MagicMock()
|
||||
empty_result.__iter__.return_value = iter([])
|
||||
mock_conn.execute.return_value = empty_result
|
||||
mock_session.execute.return_value = empty_result
|
||||
|
||||
result = delete_draft_variables_batch(app_id, batch_size)
|
||||
|
||||
assert result == 0
|
||||
assert mock_conn.execute.call_count == 1 # Only one select query
|
||||
assert mock_session.execute.call_count == 1 # Only one select query
|
||||
mock_offload_cleanup.assert_not_called() # No files to clean up
|
||||
|
||||
def test_delete_draft_variables_batch_invalid_batch_size(self):
|
||||
@@ -147,22 +139,19 @@ class TestDeleteDraftVariablesBatch:
|
||||
delete_draft_variables_batch(app_id, 0)
|
||||
|
||||
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
|
||||
@patch("tasks.remove_app_and_related_data_task.db")
|
||||
@patch("tasks.remove_app_and_related_data_task.session_factory")
|
||||
@patch("tasks.remove_app_and_related_data_task.logger")
|
||||
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db, mock_offload_cleanup):
|
||||
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_sf, mock_offload_cleanup):
|
||||
"""Test that batch deletion logs progress correctly."""
|
||||
app_id = "test-app-id"
|
||||
batch_size = 50
|
||||
|
||||
# Mock database
|
||||
mock_conn = MagicMock()
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
# Properly mock the context manager
|
||||
# Mock session via session_factory
|
||||
mock_session = MagicMock()
|
||||
mock_context_manager = MagicMock()
|
||||
mock_context_manager.__enter__.return_value = mock_conn
|
||||
mock_context_manager.__enter__.return_value = mock_session
|
||||
mock_context_manager.__exit__.return_value = None
|
||||
mock_engine.begin.return_value = mock_context_manager
|
||||
mock_sf.create_session.return_value = mock_context_manager
|
||||
|
||||
# Mock one batch then empty
|
||||
batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)]
|
||||
@@ -183,7 +172,7 @@ class TestDeleteDraftVariablesBatch:
|
||||
empty_result = MagicMock()
|
||||
empty_result.__iter__.return_value = iter([])
|
||||
|
||||
mock_conn.execute.side_effect = [
|
||||
mock_session.execute.side_effect = [
|
||||
# Select query result
|
||||
select_result,
|
||||
# Delete query result
|
||||
@@ -201,7 +190,7 @@ class TestDeleteDraftVariablesBatch:
|
||||
|
||||
# Verify offload cleanup was called with file_ids
|
||||
if batch_file_ids:
|
||||
mock_offload_cleanup.assert_called_once_with(mock_conn, batch_file_ids)
|
||||
mock_offload_cleanup.assert_called_once_with(mock_session, batch_file_ids)
|
||||
|
||||
# Verify logging calls
|
||||
assert mock_logging.info.call_count == 2
|
||||
@@ -261,19 +250,19 @@ class TestDeleteDraftVariableOffloadData:
|
||||
actual_calls = mock_conn.execute.call_args_list
|
||||
|
||||
# First call should be the SELECT query
|
||||
select_call_sql = str(actual_calls[0][0][0])
|
||||
select_call_sql = " ".join(str(actual_calls[0][0][0]).split())
|
||||
assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql
|
||||
assert "FROM workflow_draft_variable_files wdvf" in select_call_sql
|
||||
assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql
|
||||
assert "WHERE wdvf.id IN :file_ids" in select_call_sql
|
||||
|
||||
# Second call should be DELETE upload_files
|
||||
delete_upload_call_sql = str(actual_calls[1][0][0])
|
||||
delete_upload_call_sql = " ".join(str(actual_calls[1][0][0]).split())
|
||||
assert "DELETE FROM upload_files" in delete_upload_call_sql
|
||||
assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql
|
||||
|
||||
# Third call should be DELETE workflow_draft_variable_files
|
||||
delete_variable_files_call_sql = str(actual_calls[2][0][0])
|
||||
delete_variable_files_call_sql = " ".join(str(actual_calls[2][0][0]).split())
|
||||
assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql
|
||||
assert "WHERE id IN :file_ids" in delete_variable_files_call_sql
|
||||
|
||||
|
||||
Reference in New Issue
Block a user