检索优化
检索优化概述
检索是RAG系统的核心环节,检索质量直接影响最终生成效果。本文档介绍多种检索优化策略。
┌─────────────────────────────────────────────────────────────┐
│ 检索优化策略 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 🔍 查询优化 📊 检索策略 🎯 结果优化 │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ Query改写 │ │ 混合检索 │ │ 重排序 │ │
│ │ 查询扩展 │ │ 多跳检索 │ │ MMR │ │
│ │ HyDE │ │ 子查询 │ │ 过滤 │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘

查询优化
1. 查询改写
python
async def rewrite_query(query):
    """Ask the LLM for three paraphrases of ``query`` to improve recall.

    The prompt (in Chinese) requests synonym substitution, a changed
    sentence structure, and possible aliases/abbreviations, in a numbered
    output format.

    Returns:
        A list starting with the original ``query`` followed by whatever
        ``extract_queries`` parses out of the LLM response.
    """
    prompt_lines = [
        "",
        "请将以下用户问题改写成3个不同的表达方式,保持原意:",
        f"原问题:{query}",
        "要求:",
        "1. 使用同义词替换",
        "2. 改变句式结构",
        "3. 添加可能的别名或简称",
        "输出格式:",
        "1. [改写1]",
        "2. [改写2]",
        "3. [改写3]",
        "",
    ]
    prompt = "\n".join(prompt_lines)
    llm_output = await llm.invoke(prompt)
    # The original query always stays first so it is searched as well.
    return [query] + extract_queries(llm_output)


# 2. HyDE(假设性文档嵌入)
python
async def hyde_retrieve(query):
    """HyDE retrieval: search with a hypothetical answer instead of the query.

    The LLM is first asked to write a detailed expert answer to ``query``;
    that generated text is then used as the similarity-search input.

    Returns:
        The result of ``vector_store.similarity_search`` for the generated
        answer, limited to 5 hits.
    """
    # Step 1: have the LLM draft the hypothetical answer.
    prompt = "\n".join(
        [
            "",
            "假设你是一个专家,请基于以下问题写一个详细、准确的回答:",
            f"问题:{query}",
            "回答:(请写得详细、专业)",
            "",
        ]
    )
    hypothetical_answer = await llm.invoke(prompt)
    # Step 2: retrieve with the drafted answer.
    # NOTE(review): hybrid_search in this file awaits the `a`-prefixed
    # variant of the store API - confirm `similarity_search` is awaitable.
    return await vector_store.similarity_search(hypothetical_answer, k=5)


# 3. 查询扩展
python
async def expand_query(query):
    """Query expansion: ask the LLM for related concepts and synonyms.

    The prompt requests hypernyms, hyponyms, synonyms and related terms,
    and asks for a JSON object with the keys ``original``,
    ``expanded_terms`` and ``expanded_query``.

    Returns:
        The LLM response parsed with ``json.loads``.
    """
    prompt = "\n".join(
        [
            "",
            "请为以下查询扩展相关概念和关键词:",
            f"原始查询:{query}",
            "扩展方向:",
            "1. 上位词(更通用的概念)",
            "2. 下位词(更具体的概念)",
            "3. 同义词",
            "4. 相关术语",
            "输出JSON格式:",
            "{",
            '"original": "原始查询",',
            '"expanded_terms": ["扩展词1", "扩展词2", ...],',
            '"expanded_query": "组合后的查询"',
            "}",
            "",
        ]
    )
    raw_reply = await llm.invoke(prompt)
    # NOTE(review): assumes the reply is bare JSON text - confirm the LLM
    # wrapper does not return fenced or annotated output.
    return json.loads(raw_reply)


# 混合检索
语义搜索 + 关键词搜索
python
async def hybrid_search(query, top_k=10):
    """Hybrid search: fuse a semantic ranking and a BM25 keyword ranking.

    Both legs are asked for ``top_k`` hits and are expected to yield
    ``(doc, score)`` pairs whose ``doc`` exposes ``.id``.

    Returns:
        ``(doc_id, fused_score)`` tuples sorted by fused score, highest
        first. Note that ids, not document objects, are returned.
    """

    def rrf_fusion(results_list, k=60):
        """Reciprocal Rank Fusion: each hit adds 1 / (k + rank + 1)."""
        fused = {}
        for ranking in results_list:
            # Only the position in each ranking matters; raw scores are ignored.
            for position, (doc, _raw_score) in enumerate(ranking):
                doc_id = doc.id
                fused[doc_id] = fused.get(doc_id, 0) + 1 / (k + position + 1)
        return sorted(fused.items(), key=lambda item: item[1], reverse=True)

    # Leg 1: semantic (vector) search.
    by_meaning = await vector_store.asimilarity_search_with_score(query, k=top_k)
    # Leg 2: BM25 keyword search over the module-level `documents`.
    by_keyword = await keyword_search_bm25(query, documents, k=top_k)
    # Fuse the two rankings.
    return rrf_fusion([by_meaning, by_keyword])


# BM25关键词搜索实现
python
import math
from collections import Counter
class BM25:
    """Minimal BM25 scorer over pre-tokenised documents.

    Documents and queries are sequences of hashable tokens. Call ``fit``
    on a corpus first, then ``score`` a query against a single document.

    Attributes are plain instance fields: ``k1`` and ``b`` are the BM25
    parameters, ``doc_freqs`` maps token -> number of corpus documents
    containing it, ``avgdl`` is the mean document length and ``N`` the
    corpus size.
    """

    def __init__(self, k1=1.5, b=0.75):
        self.k1 = k1
        self.b = b
        self.doc_freqs = {}
        self.avgdl = 0
        # Defined up front so the attribute exists even before fit().
        self.N = 0

    def fit(self, corpus):
        """Build the BM25 index from ``corpus`` and return ``self``.

        Args:
            corpus: Sized collection of tokenised documents.

        An empty corpus yields an empty index (``N == 0``, ``avgdl == 0``)
        instead of raising ``ZeroDivisionError``.
        """
        self.doc_freqs = {}
        self.N = len(corpus)
        total_length = 0
        for document in corpus:
            total_length += len(document)
            # Document frequency: count each token once per document.
            for word in set(document):
                self.doc_freqs[word] = self.doc_freqs.get(word, 0) + 1
        self.avgdl = total_length / self.N if self.N else 0
        return self

    def score(self, query, document):
        """Return the BM25 score of ``document`` for ``query``.

        Query tokens that never occurred in the fitted corpus are skipped,
        so an unfitted or empty index always scores 0.
        """
        doc_len = len(document)
        frequencies = Counter(document)
        scores = []
        for word in query:
            if word not in self.doc_freqs:
                continue
            freq = frequencies.get(word, 0)
            n = self.doc_freqs[word]
            # IDF with the "+ 1" inside the log, which keeps it non-negative.
            idf = math.log((self.N - n + 0.5) / (n + 0.5) + 1)
            # Term frequency, normalised by document length relative to avgdl.
            tf_norm = (freq * (self.k1 + 1)) / (
                freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)
            )
            scores.append(idf * tf_norm)
        return sum(scores)


# 重排序(Rerank)
使用Cross-Encoder重排序
python
from sentence_transformers import CrossEncoder
# Load the re-ranking model (module-level, so it is created once at import).
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
async def rerank(query, candidates, top_k=5):
    """Re-rank ``candidates`` against ``query`` with the Cross-Encoder.

    A Cross-Encoder encodes query and document together, which is more
    precise but slower than comparing independent embeddings.

    Args:
        query: The query text.
        candidates: Iterable of candidate documents. Plain strings are
            scored as-is; other objects are scored by their ``.content``
            attribute when present (the text attribute used for documents
            elsewhere in this file), otherwise passed through unchanged.
        top_k: Maximum number of results to return.

    Returns:
        Up to ``top_k`` ``(candidate, score)`` tuples, highest score first.
        The original candidate objects are returned, not their text.
        An empty candidate set returns ``[]`` without calling the model.
    """
    candidates = list(candidates)
    if not candidates:
        return []
    # The model scores (query, text) pairs, so hand it the document text
    # rather than the document object.
    pairs = [
        (query, doc if isinstance(doc, str) else getattr(doc, "content", doc))
        for doc in candidates
    ]
    scores = reranker.predict(pairs)
    ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
    return ranked[:top_k]


# 两阶段检索
python
async def two_stage_retrieve(query, top_k_initial=20, top_k_final=5):
    """Two-stage retrieval: vector recall followed by re-ranking.

    Stage 1 pulls ``top_k_initial`` candidates from the vector store;
    stage 2 hands them to ``rerank`` and keeps ``top_k_final``.

    Returns:
        Whatever ``rerank`` returns for the recalled candidates.
    """
    # Stage 1 (recall): fast vector search for the candidate pool.
    recalled = await vector_store.similarity_search(query, k=top_k_initial)
    # Stage 2 (precision): re-rank the pool.
    return await rerank(query, recalled, top_k_final)


# MMR(最大边际相关)
实现多样化检索
python
async def mmr_retrieve(query, k=5, fetch_k=20, lambda_mult=0.5):
    """MMR retrieval: balance relevance against diversity.

    Args:
        query: The query text.
        k: Number of documents to select.
        fetch_k: Size of the candidate pool fetched from the vector store.
        lambda_mult: Trade-off weight. Close to 1 favours relevance,
            close to 0 favours diversity.

    Returns:
        Up to ``k`` documents, in the order they were selected.
    """
    # 1. Fetch a larger candidate pool.
    candidates = await vector_store.similarity_search(query, k=fetch_k)
    pool = list(candidates)
    if k <= 0 or not pool:
        return []

    # Relevance depends only on (query, doc), so score each candidate once
    # instead of once per selection round.
    relevance = [await compute_relevance(query, doc) for doc in pool]

    selected = []
    remaining = list(range(len(pool)))  # indices into `pool`
    while len(selected) < k and remaining:
        best_score = -float('inf')
        best_idx = None
        for idx in remaining:
            # Highest similarity to anything already selected.
            max_similarity = 0
            if selected:
                similarities = await compute_similarities(pool[idx], selected)
                max_similarity = max(similarities)
            mmr_score = (
                lambda_mult * relevance[idx]
                - (1 - lambda_mult) * max_similarity
            )
            if mmr_score > best_score:
                best_score = mmr_score
                best_idx = idx
        if best_idx is None:
            # No candidate beat -inf (e.g. NaN scores); stop rather than
            # loop forever with an unchanged `remaining`.
            break
        selected.append(pool[best_idx])
        remaining.remove(best_idx)
    return selected


# 迭代检索
多轮检索增强
python
async def iterative_retrieve(query, max_iterations=3):
    """Iterative retrieval: refine the query from each round's results.

    Each round fetches 3 hits for the current query. If ``is_sufficient``
    accepts them the loop stops; otherwise the hits' ``.content`` is
    joined into a context string and ``generate_better_query`` produces
    the query for the next round.

    Args:
        query: The original user query.
        max_iterations: Maximum number of retrieval rounds.

    Returns:
        ``deduplicate`` applied to the hits gathered across all rounds.
    """
    all_results = []
    current_query = query
    final_round = max_iterations - 1
    for round_no in range(max_iterations):
        results = await vector_store.similarity_search(current_query, k=3)
        all_results.extend(results)
        # Stop as soon as this round's hits are good enough.
        if is_sufficient(results):
            break
        # A refined query is only useful if another round will run, so
        # skip the extra LLM call after the last retrieval.
        if round_no == final_round:
            break
        context = "\n".join([doc.content for doc in results])
        current_query = await generate_better_query(
            original_query=query,
            gathered_context=context,
        )
    return deduplicate(all_results)


# 过滤与条件检索
元数据过滤
python
# Milvus/Pinecone 等支持元数据过滤
results = await vector_store.similarity_search(
query,
k=10,
filter={
"category": {"$in": ["技术", "教程"]},
"date": {"$gte": "2024-01-01"},
"author": {"$ne": "admin"}
}
)混合过滤与检索
python
async def filtered_hybrid_search(query, filters, k=10):
    """Combine hybrid search with a metadata post-filter.

    Fetches ``k * 3`` hits from ``hybrid_search``, keeps those whose
    ``doc.metadata`` satisfies ``match_filters``, and returns the first
    ``k`` survivors as ``(doc, score)`` tuples.
    """
    # 1. Candidate pool, three times larger than the requested size.
    # NOTE(review): this loop needs (doc, score) pairs where doc exposes
    # `.metadata`; the hybrid_search visible in this file yields
    # (doc_id, score) - confirm which shape is intended.
    pool = await hybrid_search(query, top_k=k * 3)
    # 2. Apply the metadata filter, preserving ranking order.
    survivors = []
    for doc, score in pool:
        if match_filters(doc.metadata, filters):
            survivors.append((doc, score))
    # 3. Truncate to the requested size.
    return survivors[:k]


# 检索评估指标
python
class RetrievalEvaluator:
    """Retrieval-quality metrics over a ranked result list.

    ``retrieved`` is the ranked list of returned items and ``relevant``
    the collection of ground-truth items; items must be hashable for the
    set-based metrics. Relevance is binary.
    """

    def precision_at_k(self, retrieved, relevant, k):
        """Precision@K: relevant hits in the top ``k``, divided by ``k``.

        Returns 0.0 for ``k <= 0`` instead of dividing by zero.
        """
        if k <= 0:
            return 0.0
        retrieved_k = retrieved[:k]
        return len(set(retrieved_k) & set(relevant)) / k

    def recall_at_k(self, retrieved, relevant, k):
        """Recall@K: relevant hits in the top ``k`` over ``len(relevant)``.

        Returns 0.0 when ``relevant`` is empty or ``k <= 0``.
        """
        if k <= 0 or not relevant:
            return 0.0
        retrieved_k = retrieved[:k]
        return len(set(retrieved_k) & set(relevant)) / len(relevant)

    def ndcg_at_k(self, retrieved, relevant, k):
        """NDCG@K with binary gains and a log2 position discount.

        Returns 0 when the ideal DCG is 0 (no relevant items or ``k <= 0``).
        """
        # Position i (0-based) is discounted by log2(i + 2), so rank 1 -> 1.0.
        dcg = 0
        for i, doc in enumerate(retrieved[:k]):
            if doc in relevant:
                dcg += 1 / math.log2(i + 2)
        # Ideal ranking: every relevant item packed at the top.
        ideal_hits = min(len(relevant), k)
        idcg = sum(1 / math.log2(i + 2) for i in range(ideal_hits))
        return dcg / idcg if idcg > 0 else 0