diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 303c3fe31..d66829837 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -188,14 +188,17 @@ class OracleVector(BaseVector): def text_exists(self, id: str) -> bool: with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,)) + cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,)) return cur.fetchone() is not None conn.close() def get_by_ids(self, ids: list[str]) -> list[Document]: + if not ids: + return [] with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + placeholders = ", ".join(f":{i + 1}" for i in range(len(ids))) + cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids) docs = [] for record in cur: docs.append(Document(page_content=record[1], metadata=record[0])) @@ -208,14 +211,15 @@ class OracleVector(BaseVector): return with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) + placeholders = ", ".join(f":{i + 1}" for i in range(len(ids))) + cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids) conn.commit() conn.close() def delete_by_metadata_field(self, key: str, value: str) -> None: with self._get_connection() as conn: with conn.cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,)) conn.commit() conn.close() @@ -227,12 +231,20 @@ class OracleVector(BaseVector): :param top_k: The number of nearest neighbors to return, default is 5. :return: List of Documents that are nearest to the query vector. """ + # Validate and sanitize top_k to prevent SQL injection top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000: + top_k = 4 # Use default if invalid + document_ids_filter = kwargs.get("document_ids_filter") where_clause = "" + params = [numpy.array(query_vector)] + if document_ids_filter: - document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) - where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" + placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter))) + where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})" + params.extend(document_ids_filter) + with self._get_connection() as conn: conn.inputtypehandler = self.input_type_handler conn.outputtypehandler = self.output_type_handler @@ -241,7 +253,7 @@ class OracleVector(BaseVector): f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine) AS distance FROM {self.table_name} {where_clause} ORDER BY distance fetch first {top_k} rows only""", - [numpy.array(query_vector)], + params, ) docs = [] score_threshold = float(kwargs.get("score_threshold") or 0.0) @@ -259,7 +271,10 @@ class OracleVector(BaseVector): import nltk # type: ignore from nltk.corpus import stopwords # type: ignore + # Validate and sanitize top_k to prevent SQL injection top_k = kwargs.get("top_k", 5) + if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000: + top_k = 5 # Use default if invalid # just not implement fetch by score_threshold now, may be later score_threshold = float(kwargs.get("score_threshold") or 0.0) if len(query) > 0: @@ -297,14 +312,21 @@ class OracleVector(BaseVector): with conn.cursor() as cur: document_ids_filter = kwargs.get("document_ids_filter") where_clause = "" + params: dict[str, Any] = {"kk": " ACCUM ".join(entities)} + if document_ids_filter: - document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) - where_clause = f" AND metadata->>'document_id' in ({document_ids}) " + placeholders = [] + for i, doc_id in enumerate(document_ids_filter): + param_name = f"doc_id_{i}" + placeholders.append(f":{param_name}") + params[param_name] = doc_id + where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) " + cur.execute( f"""select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :kk, 1) > 0 {where_clause} order by score(1) desc fetch first {top_k} rows only""", - kk=" ACCUM ".join(entities), + params, ) docs = [] for record in cur: