Skip to content

检索优化

检索优化概述

检索是RAG系统的核心环节,检索质量直接影响最终生成效果。本文档介绍多种检索优化策略。

┌─────────────────────────────────────────────────────────────┐
│                    检索优化策略                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  🔍 查询优化        📊 检索策略        🎯 结果优化           │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐     │
│  │ Query改写    │  │ 混合检索      │  │ 重排序       │     │
│  │ 查询扩展    │  │ 多跳检索      │  │ MMR          │     │
│  │ HyDE        │  │ 子查询       │  │ 过滤        │     │
│  └──────────────┘  └──────────────┘  └──────────────┘     │
│                                                             │
└─────────────────────────────────────────────────────────────┘

查询优化

1. 查询改写

python
async def rewrite_query(query):
    """使用LLM改写查询,提高检索召回率"""
    
    prompt = f"""
    请将以下用户问题改写成3个不同的表达方式,保持原意:
    
    原问题:{query}
    
    要求:
    1. 使用同义词替换
    2. 改变句式结构
    3. 添加可能的别名或简称
    
    输出格式:
    1. [改写1]
    2. [改写2]
    3. [改写3]
    """
    
    response = await llm.invoke(prompt)
    rewrites = extract_queries(response)
    
    return [query] + rewrites

2. HyDE(假设性文档嵌入)

python
async def hyde_retrieve(query):
    """
    HyDE: 先生成假设性答案,再用答案检索
    利用LLM生成"完美答案",用这个答案去找相似文档
    """
    
    # 1. 让LLM生成假设性答案
    prompt = f"""
    假设你是一个专家,请基于以下问题写一个详细、准确的回答:
    
    问题:{query}
    
    回答:(请写得详细、专业)
    """
    
    hypothetical_answer = await llm.invoke(prompt)
    
    # 2. 用假设性答案做检索
    results = await vector_store.similarity_search(
        hypothetical_answer,
        k=5
    )
    
    return results

3. 查询扩展

python
async def expand_query(query):
    """查询扩展:添加相关概念和同义词"""
    
    prompt = f"""
    请为以下查询扩展相关概念和关键词:
    
    原始查询:{query}
    
    扩展方向:
    1. 上位词(更通用的概念)
    2. 下位词(更具体的概念)
    3. 同义词
    4. 相关术语
    
    输出JSON格式:
    {{
        "original": "原始查询",
        "expanded_terms": ["扩展词1", "扩展词2", ...],
        "expanded_query": "组合后的查询"
    }}
    """
    
    response = await llm.invoke(prompt)
    return json.loads(response)

混合检索

语义搜索 + 关键词搜索

python
async def hybrid_search(query, top_k=10):
    """混合搜索:结合语义和关键词"""
    
    # 1. 语义搜索
    semantic_results = await vector_store.asimilarity_search_with_score(
        query,
        k=top_k
    )
    
    # 2. BM25关键词搜索
    keyword_results = await keyword_search_bm25(
        query,
        documents,
        k=top_k
    )
    
    # 3. RRF融合
    def rrf_fusion(results_list, k=60):
        """倒数排名融合"""
        scores = {}
        
        for results in results_list:
            for rank, (_, score) in enumerate(results):
                doc_id = results[rank][0].id
                scores[doc_id] = scores.get(doc_id, 0) + 1 / (k + rank + 1)
        
        return sorted(scores.items(), key=lambda x: x[1], reverse=True)
    
    fused = rrf_fusion([semantic_results, keyword_results])
    
    return fused

BM25关键词搜索实现

python
import math
from collections import Counter

class BM25:
    def __init__(self, k1=1.5, b=0.75):
        self.k1 = k1
        self.b = b
        self.doc_freqs = {}
        self.avgdl = 0
    
    def fit(self, corpus):
        """构建BM25索引"""
        self.doc_freqs = {}
        self.avgdl = 0
        nd = {}  # 词频
        
        for document in corpus:
            self.avgdl += len(document)
            frequencies = Counter(document)
            
            for word, freq in frequencies.items():
                self.doc_freqs[word] = self.doc_freqs.get(word, 0) + 1
            
            for word in set(document):
                nd[word] = nd.get(word, 0) + 1
        
        self.avgdl = self.avgdl / len(corpus)
        self.N = len(corpus)
        
        return self
    
    def score(self, query, document):
        """计算BM25得分"""
        scores = []
        doc_len = len(document)
        frequencies = Counter(document)
        
        for word in query:
            if word not in self.doc_freqs:
                continue
            
            freq = frequencies.get(word, 0)
            n = self.doc_freqs[word]
            
            # IDF
            idf = math.log((self.N - n + 0.5) / (n + 0.5) + 1)
            
            # TF normalization
            tf_norm = (freq * (self.k1 + 1)) / (
                freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)
            )
            
            scores.append(idf * tf_norm)
        
        return sum(scores)

重排序(Rerank)

使用Cross-Encoder重排序

python
from sentence_transformers import CrossEncoder

# 加载重排序模型
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

async def rerank(query, candidates, top_k=5):
    """
    使用Cross-Encoder对候选文档重排序
    Cross-Encoder: 同时编码query和document,更精准但更慢
    """
    
    # 构建query-document对
    pairs = [(query, doc) for doc in candidates]
    
    # 计算相关性分数
    scores = reranker.predict(pairs)
    
    # 按分数排序
    ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
    
    return ranked[:top_k]

两阶段检索

python
async def two_stage_retrieve(query, top_k_initial=20, top_k_final=5):
    """
    两阶段检索:
    1. 快速向量检索(召回)
    2. 精确重排序(精排)
    """
    
    # 第一阶段:向量检索,召回候选集
    candidates = await vector_store.similarity_search(
        query,
        k=top_k_initial
    )
    
    # 第二阶段:重排序
    reranked = await rerank(query, candidates, top_k_final)
    
    return reranked

MMR(最大边际相关)

实现多样化检索

python
async def mmr_retrieve(query, k=5, fetch_k=20, lambda_mult=0.5):
    """
    MMR: 在相关性和多样性之间取得平衡
    
    lambda_mult: 
    - 接近1: 更注重相关性
    - 接近0: 更注重多样性
    """
    
    # 1. 初步检索更多候选
    candidates = await vector_store.similarity_search(
        query,
        k=fetch_k
    )
    
    selected = []
    remaining = candidates.copy()
    
    while len(selected) < k and remaining:
        best_score = -float('inf')
        best_doc = None
        
        for doc in remaining:
            # 相关性分数
            relevance = await compute_relevance(query, doc)
            
            # 最大相似度(与已选文档)
            max_similarity = 0
            if selected:
                similarities = await compute_similarities(doc, selected)
                max_similarity = max(similarities)
            
            # MMR分数
            mmr_score = lambda_mult * relevance - (1 - lambda_mult) * max_similarity
            
            if mmr_score > best_score:
                best_score = mmr_score
                best_doc = doc
        
        if best_doc:
            selected.append(best_doc)
            remaining.remove(best_doc)
    
    return selected

迭代检索

多轮检索增强

python
async def iterative_retrieve(query, max_iterations=3):
    """
    迭代检索:每轮根据上轮结果优化检索
    """
    
    context = ""
    all_results = []
    current_query = query
    
    for i in range(max_iterations):
        # 检索
        results = await vector_store.similarity_search(
            current_query,
            k=3
        )
        
        # 检查是否需要继续
        if is_sufficient(results):
            all_results.extend(results)
            break
        
        # 聚合已有信息,生成更好的查询
        context = "\n".join([doc.content for doc in results])
        
        current_query = await generate_better_query(
            original_query=query,
            gathered_context=context
        )
        
        all_results.extend(results)
    
    return deduplicate(all_results)

过滤与条件检索

元数据过滤

python
# Milvus/Pinecone 等支持元数据过滤
results = await vector_store.similarity_search(
    query,
    k=10,
    filter={
        "category": {"$in": ["技术", "教程"]},
        "date": {"$gte": "2024-01-01"},
        "author": {"$ne": "admin"}
    }
)

混合过滤与检索

python
async def filtered_hybrid_search(query, filters, k=10):
    """结合过滤和混合检索"""
    
    # 1. 获取候选文档
    candidates = await hybrid_search(query, top_k=k * 3)
    
    # 2. 应用过滤
    filtered = [
        (doc, score) 
        for doc, score in candidates 
        if match_filters(doc.metadata, filters)
    ]
    
    # 3. 返回top_k
    return filtered[:k]

检索评估指标

python
class RetrievalEvaluator:
    """检索效果评估"""
    
    def precision_at_k(self, retrieved, relevant, k):
        """Precision@K"""
        retrieved_k = retrieved[:k]
        return len(set(retrieved_k) & set(relevant)) / k
    
    def recall_at_k(self, retrieved, relevant, k):
        """Recall@K"""
        retrieved_k = retrieved[:k]
        return len(set(retrieved_k) & set(relevant)) / len(relevant)
    
    def ndcg_at_k(self, retrieved, relevant, k):
        """NDCG@K"""
        dcg = 0
        for i, doc in enumerate(retrieved[:k]):
            if doc in relevant:
                dcg += 1 / math.log2(i + 2)
        
        idcg = sum(1 / math.log2(i + 2) for i in range(min(len(relevant), k)))
        
        return dcg / idcg if idcg > 0 else 0

Released under the MIT License.