Skip to content

记忆管理

Agent记忆概述

记忆是Agent的核心能力之一,决定了Agent能否基于历史上下文进行连续推理和个性化服务。

┌─────────────────────────────────────────────────────────────┐
│                    Agent 记忆系统架构                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────────────────────────────────────────────┐   │
│  │                    记忆存储层                        │   │
│  │                                                     │   │
│  │  ┌─────────────┐  ┌─────────────┐  ┌───────────┐ │   │
│  │  │ 感官记忆    │→ │ 短期记忆    │→ │ 长期记忆  │ │   │
│  │  │ (感知输入)  │  │ (工作记忆)  │  │ (向量存储)│ │   │
│  │  └─────────────┘  └─────────────┘  └───────────┘ │   │
│  │        │                │                │        │   │
│  │        └────────────────┼────────────────┘        │   │
│  │                         ▼                         │   │
│  │              ┌───────────────────┐                │   │
│  │              │     记忆检索       │                │   │
│  │              │  (相关记忆召回)   │                │   │
│  │              └───────────────────┘                │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

记忆类型

1. 短期记忆(Working Memory)

python
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Dict, Any

@dataclass
class Message:
    """单条消息"""
    role: str  # user/assistant/system
    content: str
    timestamp: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)

class ShortTermMemory:
    """短期记忆:当前会话的上下文"""
    
    def __init__(self, max_messages: int = 50):
        self.max_messages = max_messages
        self.messages: List[Message] = []
    
    def add(self, role: str, content: str, metadata: Dict = None):
        """添加消息"""
        self.messages.append(Message(
            role=role,
            content=content,
            metadata=metadata or {}
        ))
        
        # 超出容量时触发压缩
        if len(self.messages) > self.max_messages:
            self.compress()
    
    def get_context(self, last_n: int = None) -> str:
        """获取对话上下文"""
        messages = self.messages[-last_n:] if last_n else self.messages
        return "\n".join([
            f"{msg.role}: {msg.content}"
            for msg in messages
        ])
    
    def compress(self):
        """压缩短期记忆:摘要化早期内容"""
        if len(self.messages) <= self.max_messages // 2:
            return
        
        # 保留最近的50%
        keep_count = self.max_messages // 2
        recent = self.messages[-keep_count:]
        old = self.messages[:-keep_count]
        
        # 生成摘要
        summary = self._summarize(old)
        
        # 替换为摘要消息
        self.messages = [
            Message(
                role="system",
                content=f"[早期对话摘要] {summary}",
                metadata={"type": "summary"}
            )
        ] + recent
    
    def _summarize(self, messages: List[Message]) -> str:
        """生成摘要"""
        # 调用LLM生成摘要
        content = "\n".join([m.content for m in messages])
        prompt = f"请总结以下对话的关键信息:\n{content[:2000]}"
        # ... 调用LLM
        return "用户询问了X问题,AI提供了Y建议..."

2. 长期记忆(Long-term Memory)

python
import chromadb
from datetime import datetime

class LongTermMemory:
    """长期记忆:持久化存储,可检索"""
    
    def __init__(self, collection_name: str = "agent_memory"):
        self.client = chromadb.Client()
        self.collection = self.client.create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}
        )
    
    def store(self, content: str, memory_type: str, metadata: Dict = None):
        """
        存储记忆
        
        memory_type: episodic(事件记忆) / semantic(语义记忆) / procedural(程序记忆)
        """
        from langchain_openai import OpenAIEmbeddings
        
        embeddings = OpenAIEmbeddings()
        vector = embeddings.embed_query(content)
        
        self.collection.add(
            documents=[content],
            embeddings=[vector],
            ids=[str(datetime.now().timestamp())],
            metadatas=[{
                "type": memory_type,
                "timestamp": datetime.now().isoformat(),
                **(metadata or {})
            }]
        )
    
    def retrieve(self, query: str, memory_type: str = None, k: int = 5) -> List[Dict]:
        """检索相关记忆"""
        from langchain_openai import OpenAIEmbeddings
        
        embeddings = OpenAIEmbeddings()
        query_vector = embeddings.embed_query(query)
        
        where = {"type": memory_type} if memory_type else None
        
        results = self.collection.query(
            query_embeddings=[query_vector],
            n_results=k,
            where=where
        )
        
        return [
            {
                "content": doc,
                "metadata": meta,
                "distance": dist
            }
            for doc, meta, dist in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0]
            )
        ]

3. 程序记忆(Procedural Memory)

python
class ProceduralMemory:
    """程序记忆:Agent的学习模式和执行流程"""
    
    def __init__(self):
        self.patterns = {}  # 存储常见模式
    
    def learn_pattern(self, task_type: str, solution: str):
        """学习任务解决模式"""
        self.patterns[task_type] = {
            "solution": solution,
            "success_count": 1,
            "last_used": datetime.now()
        }
    
    def recall_pattern(self, task_type: str) -> str:
        """回忆解决模式"""
        if task_type in self.patterns:
            pattern = self.patterns[task_type]
            pattern["last_used"] = datetime.now()
            return pattern["solution"]
        return None
    
    def update_success(self, task_type: str, success: bool):
        """更新模式成功率"""
        if task_type in self.patterns:
            pattern = self.patterns[task_type]
            pattern["success_count"] += 1 if success else 0
            pattern["total_attempts"] = pattern.get("total_attempts", 0) + 1

记忆管理系统

python
class AgentMemorySystem:
    """统一的记忆管理系统"""
    
    def __init__(self, config: Dict):
        self.short_term = ShortTermMemory(config.get("max_short_term", 50))
        self.long_term = LongTermMemory(config.get("collection", "memory"))
        self.procedural = ProceduralMemory()
    
    def add_interaction(self, role: str, content: str, metadata: Dict = None):
        """记录交互"""
        self.short_term.add(role, content, metadata)
        
        # 重要内容存入长期记忆
        if metadata and metadata.get("important"):
            self.long_term.store(
                content=content,
                memory_type="important_interaction",
                metadata=metadata
            )
    
    def remember(self, query: str, context_limit: int = 10) -> str:
        """
        获取相关记忆构建上下文
        
        1. 获取短期记忆(最近的对话)
        2. 检索长期记忆(相关历史)
        3. 组装上下文
        """
        # 短期记忆
        short_context = self.short_term.get_context(context_limit)
        
        # 长期记忆检索
        relevant_memories = self.long_term.retrieve(query, k=3)
        long_context = "\n".join([
            f"[记忆] {m['content']}"
            for m in relevant_memories
        ])
        
        # 组装
        if long_context:
            return f"{short_context}\n\n{long_context}"
        return short_context
    
    def store_learning(self, task: str, solution: str, outcome: str):
        """存储学习成果"""
        self.long_term.store(
            content=f"任务: {task}\n解决方案: {solution}\n结果: {outcome}",
            memory_type="learning",
            metadata={"task_category": categorize(task)}
        )
        
        # 更新程序记忆
        self.procedural.learn_pattern(categorize(task), solution)

记忆检索策略

1. 基于时间的检索

python
from datetime import timedelta

def retrieve_recent(self, query: str, days: int = 7, k: int = 5):
    """检索最近N天的记忆"""
    cutoff = (datetime.now() - timedelta(days=days)).isoformat()
    
    results = self.collection.query(
        query_embeddings=[embed(query)],
        n_results=k,
        where={"timestamp": {"$gte": cutoff}}
    )
    
    return results

2. 基于重要性的检索

python
def retrieve_important(self, query: str, k: int = 5):
    """优先检索重要记忆"""
    results = self.collection.query(
        query_embeddings=[embed(query)],
        n_results=k * 2,  # 多取一些
    )
    
    # 按重要性和相关性排序
    scored = [
        {
            **r,
            "score": r["distance"] * (1 / r["metadata"].get("importance", 1))
        }
        for r in results
    ]
    
    return sorted(scored, key=lambda x: x["score"])[:k]

3. 混合检索

python
async def hybrid_retrieve(self, query: str, k: int = 5):
    """
    混合检索:结合语义、时间和重要性
    """
    # 1. 语义检索
    semantic_results = await self.semantic_search(query, k * 2)
    
    # 2. 时间加权
    time_weighted = self.apply_time_decay(semantic_results)
    
    # 3. 重要性加权
    importance_weighted = self.apply_importance(time_weighted)
    
    # 4. 去重
    deduped = self.deduplicate(importance_weighted)
    
    return deduped[:k]

记忆优化策略

1. 记忆遗忘机制

python
class MemoryForgetting:
    """记忆遗忘:定期清理低价值记忆"""
    
    def __init__(self, decay_rate: float = 0.95):
        self.decay_rate = decay_rate
    
    def decay_memory(self, memory):
        """
        记忆衰减:最近使用少且相关性低的记忆逐渐衰减
        """
        days_since_use = (datetime.now() - memory["last_accessed"]).days
        relevance = 1 - memory["distance"]  # 向量距离转相关性
        
        # 计算衰减后的价值
        value = relevance * (self.decay_rate ** days_since_use)
        
        return value
    
    def should_forget(self, memory, threshold: float = 0.1):
        """判断是否应该遗忘"""
        return self.decay_memory(memory) < threshold
    
    def cleanup(self, memories, threshold: float = 0.1):
        """清理低价值记忆"""
        return [m for m in memories if not self.should_forget(m, threshold)]

2. 记忆压缩

python
async def compress_memories(memories: List[str], llm) -> str:
    """
    使用LLM压缩多条记忆为一条摘要
    """
    prompt = f"""
    请将以下多条记忆压缩为一条简洁的摘要,保留关键信息:
    
    记忆列表:
    {' '.join(memories)}
    
    要求:
    1. 保留关键事实和结论
    2. 去除冗余信息
    3. 输出简洁明了
    """
    
    response = await llm.invoke(prompt)
    return response

3. 记忆优先级

python
class MemoryPriority:
    """记忆优先级管理"""
    
    PRIORITY_HIGH = 3
    PRIORITY_MEDIUM = 2
    PRIORITY_LOW = 1
    
    def assess_priority(self, content: str, context: Dict) -> int:
        """
        评估记忆优先级
        """
        # 1. 明确标记的优先
        if context.get("important"):
            return self.PRIORITY_HIGH
        
        # 2. 用户明确表达的兴趣
        if any(kw in content for kw in ["喜欢", "想要", "关心", "需要"]):
            return self.PRIORITY_MEDIUM
        
        # 3. 事实性信息
        if self.is_factual(content):
            return self.PRIORITY_MEDIUM
        
        return self.PRIORITY_LOW

最佳实践

┌─────────────────────────────────────────────────────────────┐
│                   记忆管理 Checklist                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ✅ 分离短期和长期记忆,职责明确                             │
│  ✅ 短期记忆设置容量限制,防止无限增长                       │
│  ✅ 定期将重要短期记忆转存长期记忆                         │
│  ✅ 长期记忆使用向量检索,提高召回效率                     │
│  ✅ 实现记忆遗忘机制,避免存储膨胀                         │
│  ✅ 记忆检索时考虑时间衰减和重要性                         │
│  ✅ 记录学习成果,优化未来决策                             │
│  ✅ 保护隐私信息,不存储敏感内容                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Released under the MIT License.