fix: replace all dataset.Model.query to db.session.query(Model) (#19509)

This commit is contained in:
非法操作
2025-05-12 13:52:33 +08:00
committed by GitHub
parent 982a8ac61a
commit fa226ece81
21 changed files with 430 additions and 265 deletions
@@ -46,14 +46,22 @@ class DatasetIndexToolCallbackHandler:
DatasetDocument.id == document.metadata["document_id"]
).first()
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
)
if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
)
else:
query = db.session.query(DocumentSegment).filter(
+16 -12
View File
@@ -51,7 +51,7 @@ class IndexingRunner:
for dataset_document in dataset_documents:
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@@ -103,15 +103,17 @@ class IndexingRunner:
"""Run the indexing process when the index_status is splitting."""
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id, document_id=dataset_document.id
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
for document_segment in document_segments:
db.session.delete(document_segment)
@@ -162,15 +164,17 @@ class IndexingRunner:
"""Run the indexing process when the index_status is indexing."""
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id, document_id=dataset_document.id
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
documents = []
if document_segments:
@@ -254,7 +258,7 @@ class IndexingRunner:
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
@@ -587,7 +591,7 @@ class IndexingRunner:
@staticmethod
def _process_keyword_index(flask_app, dataset_id, document_id, documents):
with flask_app.app_context():
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
keyword = Keyword(dataset)
@@ -676,7 +680,7 @@ class IndexingRunner:
"""
Update the document segment by document id.
"""
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
def _transform(
+17 -8
View File
@@ -237,7 +237,7 @@ class DatasetRetrieval:
if show_retrieve_source:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
@@ -511,14 +511,23 @@ class DatasetRetrieval:
).first()
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
)
if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
db.session.commit()
else:
@@ -84,13 +84,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
.all()
)
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
@@ -106,12 +110,16 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
)
if dataset and document:
source = {
"position": resource_number,
@@ -185,7 +185,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if self.return_resource:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
@@ -275,12 +275,16 @@ class KnowledgeRetrievalNode(LLMNode):
if records:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
)
if dataset and document:
source = {
"metadata": {