第29章:RAG检索增强生成技术
"知识的力量不在于拥有多少,而在于能否在需要时精准检索并智能应用。"
🎯 本章学习目标
📚 知识目标
- 理解RAG核心原理:掌握检索增强生成的技术架构和工作机制
- 掌握向量数据库技术:学习文档向量化、存储和检索的完整流程
- 熟悉检索策略优化:理解密集检索、稀疏检索和混合检索策略
- 了解生成质量控制:学习基于检索的答案生成和质量评估方法
🛠️ 技能目标
- 设计RAG系统架构:能够设计完整的检索增强生成系统
- 实现向量检索引擎:开发高效的语义检索和匹配算法
- 构建知识问答系统:建立企业级智能问答解决方案
- 优化检索生成质量:掌握检索精度和生成质量的平衡技术
🌟 素养目标
- 信息检索思维:培养系统性的知识管理和信息检索理念
- 工程化意识:建立大规模知识库系统的工程化思维
- 创新应用能力:具备将RAG技术应用于实际业务场景的能力
🏢 欢迎来到知识检索中心
经过前面章节对AI模型技术的深入学习,我们已经掌握了从基础机器学习到高级智能体开发的完整技术栈。现在,让我们走进一个全新的智能世界——知识检索中心!
🌆 知识检索中心全景图
想象一下,你正站在一座现代化的智能信息大厦前,这里就是知识检索中心的总部。
🎭 从记忆到检索的智能进化
如果说传统的AI模型是一位博学的学者,那么RAG系统就是一座拥有无限扩展能力的智能图书馆:
- 📚 海量知识存储:不再受限于模型参数,可以无限扩展知识库
- 🔍 精准信息检索:通过语义理解快速定位相关信息
- 🧠 智能答案生成:结合检索到的知识生成准确、有根据的回答
- 🔄 实时知识更新:可以随时添加新知识而无需重新训练模型
29.1 RAG系统概述与核心原理
🧭 什么是RAG检索增强生成
**RAG(Retrieval-Augmented Generation)**是一种将信息检索与文本生成相结合的AI技术架构。它通过在生成过程中动态检索相关知识,显著提升了语言模型的知识覆盖面和答案准确性。
🏗️ RAG vs 传统LLM对比
让我们通过一个生动的对比来理解RAG的优势:
```python
# RAG系统核心架构演示
import numpy as np
from typing import List, Dict, Tuple, Any
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
import json


@dataclass
class Document:
    """文档数据结构"""
    id: str
    title: str
    content: str
    metadata: Dict[str, Any]
    embedding: np.ndarray = None
    created_at: datetime = None

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.now()


@dataclass
class RetrievalResult:
    """检索结果数据结构"""
    document: Document
    score: float
    relevance_explanation: str


@dataclass
class RAGResponse:
    """RAG系统响应结构"""
    query: str
    retrieved_docs: List[RetrievalResult]
    generated_answer: str
    confidence_score: float
    sources: List[str]


class TraditionalLLM:
    """传统LLM模拟类"""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.knowledge_cutoff = "2023-04"  # 知识截止时间
        self.parameter_count = "175B"      # 参数量

    def generate_answer(self, query: str) -> str:
        """基于参数化知识生成答案"""
        # 模拟传统LLM的局限性
        limitations = [
            "知识截止于训练时间",
            "无法获取最新信息",
            "可能产生幻觉",
            "无法提供信息来源",
        ]
        return f"""传统LLM回答:{query}

基于我的训练数据(截止到{self.knowledge_cutoff}),我认为...

注意:此答案基于预训练知识,可能不是最新信息。
局限性:{', '.join(limitations)}"""


class RAGSystem:
    """RAG系统核心类"""

    def __init__(self, llm_model: str, vector_db_config: Dict):
        self.llm_model = llm_model
        self.vector_db = None        # 向量数据库连接
        self.embedding_model = None  # 嵌入模型
        self.documents: List[Document] = []
        self.retrieval_top_k = 5

    def add_documents(self, documents: List[Document]):
        """添加文档到知识库"""
        for doc in documents:
            # 生成文档嵌入向量
            doc.embedding = self._generate_embedding(doc.content)
            self.documents.append(doc)
        print(f"✅ 已添加 {len(documents)} 个文档到知识库")

    def _generate_embedding(self, text: str) -> np.ndarray:
        """生成文本嵌入向量(模拟)"""
        # 这里应该调用真实的嵌入模型
        return np.random.random(768)  # 模拟768维向量

    def retrieve_relevant_docs(self, query: str, top_k: int = None) -> List[RetrievalResult]:
        """检索相关文档"""
        if top_k is None:
            top_k = self.retrieval_top_k

        # 生成查询嵌入
        query_embedding = self._generate_embedding(query)

        # 计算相似度分数
        results = []
        for doc in self.documents:
            # 使用余弦相似度(模拟)
            similarity = np.random.random()  # 模拟相似度计算
            result = RetrievalResult(
                document=doc,
                score=similarity,
                relevance_explanation=f"与查询在语义上相关度为 {similarity:.3f}",
            )
            results.append(result)

        # 按相似度排序并返回top_k
        results.sort(key=lambda x: x.score, reverse=True)
        return results[:top_k]

    def generate_answer(self, query: str) -> RAGResponse:
        """生成RAG增强回答"""
        # 步骤1:检索相关文档
        retrieved_docs = self.retrieve_relevant_docs(query)

        # 步骤2:构建增强上下文
        context_parts = []
        sources = []
        for i, result in enumerate(retrieved_docs):
            doc = result.document
            context_parts.append(f"参考文档{i+1}:{doc.content[:200]}...")
            sources.append(f"{doc.title} (相关度: {result.score:.3f})")
        enhanced_context = "\n\n".join(context_parts)

        # 步骤3:生成增强回答
        enhanced_prompt = f"""基于以下检索到的相关文档,回答用户问题:

用户问题:{query}

相关文档:
{enhanced_context}

请基于上述文档内容给出准确、有根据的回答,并在回答中引用具体来源。"""

        # 模拟LLM生成过程
        generated_answer = f"""基于检索到的相关文档,我可以为您提供以下回答:

[基于文档内容的详细回答...]

该回答基于 {len(retrieved_docs)} 个相关文档,具有较高的可信度。"""

        return RAGResponse(
            query=query,
            retrieved_docs=retrieved_docs,
            generated_answer=generated_answer,
            confidence_score=0.85,  # 模拟置信度
            sources=sources,
        )


# 系统对比演示
def compare_systems():
    """对比传统LLM和RAG系统"""
    print("🔍 AI问答系统对比演示")
    print("=" * 50)

    # 创建系统实例
    traditional_llm = TraditionalLLM("GPT-3.5")
    rag_system = RAGSystem("GPT-3.5", {"type": "faiss"})

    # 添加一些示例文档到RAG系统
    sample_docs = [
        Document(
            id="doc1",
            title="Python RAG技术白皮书",
            content="RAG技术通过结合检索和生成,显著提升了AI系统的知识覆盖面...",
            metadata={"category": "技术文档", "date": "2024-12"},
        ),
        Document(
            id="doc2",
            title="向量数据库最佳实践",
            content="向量数据库是RAG系统的核心组件,负责高效存储和检索语义向量...",
            metadata={"category": "技术指南", "date": "2024-11"},
        ),
    ]
    rag_system.add_documents(sample_docs)

    # 测试查询
    query = "什么是RAG技术,它有什么优势?"
    print(f"\n📝 用户问题:{query}")
    print("\n" + "=" * 50)

    # 传统LLM回答
    print("🤖 传统LLM回答:")
    traditional_answer = traditional_llm.generate_answer(query)
    print(traditional_answer)

    print("\n" + "=" * 50)

    # RAG系统回答
    print("🔍 RAG系统回答:")
    rag_response = rag_system.generate_answer(query)
    print(f"生成的回答:{rag_response.generated_answer}")
    print(f"置信度:{rag_response.confidence_score}")
    print(f"参考来源:{', '.join(rag_response.sources)}")


# 运行对比演示
if __name__ == "__main__":
    compare_systems()
    print("✅ RAG系统核心架构演示完成")
```
🎯 RAG系统的核心优势
通过上面的对比演示,我们可以清楚地看到RAG系统的显著优势:
- **实时知识更新**:添加新文档即可扩充知识,无需重新训练模型
- **减少幻觉**:回答基于检索到的真实文档内容生成
- **可追溯性**:回答附带来源文档和相关度分数,便于核查
- **知识规模不受限**:知识库容量不受模型参数量限制,可以持续扩展
🏗️ RAG系统工作流程详解
让我们深入了解RAG系统的完整工作流程:
class RAGWorkflowDemo:"""RAG工作流程演示类"""def __init__(self):self.workflow_steps = ["文档预处理","向量化编码","向量存储","查询处理","相似度检索","上下文构建","答案生成","质量验证"]def demonstrate_workflow(self, user_query: str):"""演示完整的RAG工作流程"""print("🔄 RAG系统工作流程演示")print("=" * 60)# 步骤1:文档预处理(离线阶段)print("\n📚 步骤1:文档预处理(离线阶段)")print("- 收集和清洗文档数据")print("- 文档分块和格式标准化")print("- 提取元数据信息")raw_documents = ["RAG技术是一种结合检索和生成的AI架构...","向量数据库用于存储和检索高维向量数据...","语义搜索通过理解查询意图提供精准结果..."]processed_chunks = []for i, doc in enumerate(raw_documents):chunk = {"id": f"chunk_{i}","content": doc,"length": len(doc),"metadata": {"source": f"document_{i}.txt"}}processed_chunks.append(chunk)print(f"✅ 处理完成:{len(processed_chunks)} 个文档块")# 步骤2:向量化编码print("\n🧮 步骤2:向量化编码")print("- 使用预训练嵌入模型编码文档")print("- 生成高维语义向量表示")embeddings = []for chunk in processed_chunks:# 模拟向量化过程embedding = np.random.random(384) # 384维向量embeddings.append(embedding)print(f" 文档块 {chunk['id']}: 向量维度 {len(embedding)}")print(f"✅ 向量化完成:{len(embeddings)} 个向量")# 步骤3:向量存储print("\n💾 步骤3:向量存储")print("- 将向量存储到向量数据库")print("- 建立高效的索引结构")vector_index = {"vectors": embeddings,"metadata": [chunk["metadata"] for chunk in processed_chunks],"index_type": "HNSW", # 层次化小世界图"dimension": 384}print(f"✅ 存储完成:{len(vector_index['vectors'])} 个向量已索引")# 步骤4:查询处理(在线阶段)print(f"\n❓ 步骤4:查询处理(在线阶段)")print(f"- 用户查询:{user_query}")print("- 查询预处理和标准化")processed_query = {"original": user_query,"cleaned": user_query.lower().strip(),"tokens": user_query.split(),"intent": "信息查询"}print(f"✅ 查询处理完成:{processed_query['intent']}")# 步骤5:相似度检索print("\n🔍 步骤5:相似度检索")print("- 将查询向量化")print("- 计算与文档向量的相似度")print("- 检索最相关的文档块")query_embedding = np.random.random(384) # 模拟查询向量similarities = []for i, doc_embedding in enumerate(embeddings):# 模拟余弦相似度计算similarity = np.random.random()similarities.append({"chunk_id": f"chunk_{i}","similarity": similarity,"content": processed_chunks[i]["content"][:50] + "..."})# 按相似度排序similarities.sort(key=lambda x: x["similarity"], reverse=True)top_results = similarities[:3] # 取前3个最相关的print("📊 检索结果(按相关性排序):")for i, result in enumerate(top_results):print(f" {i+1}. 相似度: {result['similarity']:.3f} | {result['content']}")# 步骤6:上下文构建print("\n📝 步骤6:上下文构建")print("- 整合检索到的相关文档")print("- 构建增强上下文")context_parts = []for result in top_results:context_parts.append(f"参考内容:{result['content']}")enhanced_context = "\n".join(context_parts)print(f"✅ 上下文构建完成:{len(context_parts)} 个参考文档")# 步骤7:答案生成print("\n🤖 步骤7:答案生成")print("- 结合查询和检索上下文")print("- 使用语言模型生成回答")generation_prompt = f"""基于以下检索到的相关内容,回答用户问题:用户问题:{user_query}相关内容:{enhanced_context}请提供准确、有根据的回答。"""# 模拟生成过程generated_answer = f"基于检索到的相关文档,{user_query}的答案是..."print(f"✅ 答案生成完成:{len(generated_answer)} 字符")# 步骤8:质量验证print("\n🔍 步骤8:质量验证")print("- 验证答案与检索内容的一致性")print("- 评估答案质量和可信度")quality_metrics = {"相关性评分": 0.87,"准确性评分": 0.92,"完整性评分": 0.85,"可信度评分": 0.89}print("📊 质量评估结果:")for metric, score in quality_metrics.items():print(f" {metric}: {score:.2f}")overall_score = sum(quality_metrics.values()) / len(quality_metrics)print(f"✅ 综合质量评分:{overall_score:.2f}")return {"query": user_query,"retrieved_docs": top_results,"generated_answer": generated_answer,"quality_score": overall_score}# 运行工作流程演示workflow_demo = RAGWorkflowDemo()result = workflow_demo.demonstrate_workflow("什么是RAG技术?")print("\n" + "=" * 60)print("🎉 RAG工作流程演示完成!")print(f"📋 最终结果:{result['generated_answer']}")print(f"🏆 质量评分:{result['quality_score']:.2f}")
🎯 RAG系统的技术架构层次
RAG系统可以分为以下几个核心技术层次:
- **文档处理层**:文档解析、清洗与分块(详见29.2节)
- **向量化层**:使用嵌入模型将文本编码为语义向量(详见29.2节)
- **存储与索引层**:向量数据库负责向量的存储与高效相似度检索(详见29.3节)
- **检索策略层**:查询理解、策略选择与结果重排序(详见29.4节)
- **生成与质量控制层**:基于检索上下文生成答案,并评估答案质量
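下面给出一个贯穿上述层次的最小流程草图。注意这只是示意:其中的 embed、retrieve 等函数均为本节为演示而假设的简化实现,用词频向量代替真实语义向量,并非任何特定库的API。

```python
# 最小RAG流程草图:文档入库(离线) + 查询回答(在线)
# 注意:embed() 用简化的词频向量近似语义向量,仅用于演示各层之间的数据流
from collections import Counter
import math

def embed(text: str) -> Counter:
    """向量化层(示意):用词频向量近似语义向量"""
    return Counter(text.lower().split())

def cosine(a: Counter, b: Counter) -> float:
    """相似度计算(示意):稀疏词频向量的余弦相似度"""
    dot = sum(a[t] * b[t] for t in a)
    na = math.sqrt(sum(v * v for v in a.values()))
    nb = math.sqrt(sum(v * v for v in b.values()))
    return dot / (na * nb) if na and nb else 0.0

# 存储与索引层(示意):内存中的"向量库"
index = []  # [(文本, 向量)]

def ingest(docs):
    """文档处理 + 向量化 + 入库"""
    for doc in docs:
        index.append((doc, embed(doc)))

def retrieve(query: str, k: int = 2):
    """检索层:按相似度返回top-k文档"""
    qv = embed(query)
    scored = sorted(index, key=lambda item: cosine(qv, item[1]), reverse=True)
    return [doc for doc, _ in scored[:k]]

def answer(query: str) -> str:
    """生成层(示意):这里仅拼接上下文,实际系统应调用LLM"""
    context = "\n".join(retrieve(query))
    return f"[基于以下上下文生成回答]\n{context}"

ingest(["RAG 结合检索与生成", "向量数据库存储语义向量"])
print(answer("什么是RAG"))
```

在真实系统中,向量化层会替换为预训练嵌入模型,存储层会替换为向量数据库,生成层会调用大语言模型,但各层之间的数据流与此示意一致。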
通过这个全面的概述,我们建立了对RAG技术的基础理解。在接下来的章节中,我们将深入探讨每个技术组件的具体实现和优化策略。
本节我们学习了RAG系统的核心概念、工作原理和技术架构。下一节我们将深入学习文档处理与向量化技术,这是构建高质量RAG系统的基础。
29.2 文档处理与向量化技术
🏭 信息预处理工厂
在我们的知识检索中心中,信息预处理工厂是整个系统的起点。就像一座现代化的工厂,它负责将各种原始文档转化为标准化、结构化的知识单元。
📄 文档类型与格式处理
RAG系统需要处理多种类型的文档,让我们构建一个通用的文档处理器:
# 文档处理与向量化系统import osimport reimport jsonfrom typing import List, Dict, Any, Optional, Unionfrom dataclasses import dataclass, fieldfrom enum import Enumimport hashlibfrom pathlib import Pathclass DocumentType(Enum):"""文档类型枚举"""TEXT = "text"PDF = "pdf"WORD = "word"HTML = "html"MARKDOWN = "markdown"JSON = "json"CSV = "csv"@dataclassclass DocumentChunk:"""文档分块数据结构"""id: strcontent: strmetadata: Dict[str, Any]chunk_index: intparent_doc_id: strembedding: Optional[np.ndarray] = Nonedef __post_init__(self):# 生成内容哈希作为唯一标识if not self.id:content_hash = hashlib.md5(self.content.encode()).hexdigest()[:8]self.id = f"{self.parent_doc_id}_chunk_{self.chunk_index}_{content_hash}"class DocumentProcessor:"""通用文档处理器"""def __init__(self):self.supported_types = {'.txt': DocumentType.TEXT,'.md': DocumentType.MARKDOWN,'.pdf': DocumentType.PDF,'.docx': DocumentType.WORD,'.html': DocumentType.HTML,'.json': DocumentType.JSON,'.csv': DocumentType.CSV}# 文档处理统计self.processing_stats = {"total_docs": 0,"successful_docs": 0,"failed_docs": 0,"total_chunks": 0}def detect_document_type(self, file_path: str) -> DocumentType:"""检测文档类型"""file_extension = Path(file_path).suffix.lower()return self.supported_types.get(file_extension, DocumentType.TEXT)def extract_text_from_file(self, file_path: str) -> str:"""从文件中提取文本内容"""doc_type = self.detect_document_type(file_path)try:if doc_type == DocumentType.TEXT:return self._extract_from_text(file_path)elif doc_type == DocumentType.MARKDOWN:return self._extract_from_markdown(file_path)elif doc_type == DocumentType.PDF:return self._extract_from_pdf(file_path)elif doc_type == DocumentType.HTML:return self._extract_from_html(file_path)elif doc_type == DocumentType.JSON:return self._extract_from_json(file_path)else:# 默认按文本处理return self._extract_from_text(file_path)except Exception as e:print(f"❌ 处理文件 {file_path} 时出错: {str(e)}")return ""def _extract_from_text(self, file_path: str) -> str:"""提取纯文本文件内容"""with open(file_path, 'r', encoding='utf-8') as f:return f.read()def _extract_from_markdown(self, file_path: str) -> str:"""提取Markdown文件内容"""with open(file_path, 'r', encoding='utf-8') as f:content = f.read()# 移除Markdown标记,保留纯文本# 移除代码块content = re.sub(r'```[\s\S]*?```', '', content)# 移除内联代码content = re.sub(r'`[^`]*`', '', content)# 移除链接但保留文本content = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', content)# 移除标题标记content = re.sub(r'^#+\s*', '', content, flags=re.MULTILINE)# 移除粗体和斜体标记content = re.sub(r'\*\*([^*]+)\*\*', r'\1', content)content = re.sub(r'\*([^*]+)\*', r'\1', content)return content.strip()def _extract_from_pdf(self, file_path: str) -> str:"""提取PDF文件内容(需要PyPDF2库)"""try:import PyPDF2with open(file_path, 'rb') as file:pdf_reader = PyPDF2.PdfReader(file)text_content = []for page in pdf_reader.pages:text_content.append(page.extract_text())return '\n'.join(text_content)except ImportError:print("⚠️ 需要安装PyPDF2库来处理PDF文件: pip install PyPDF2")return ""def _extract_from_html(self, file_path: str) -> str:"""提取HTML文件内容(需要BeautifulSoup库)"""try:from bs4 import BeautifulSoupwith open(file_path, 'r', encoding='utf-8') as f:soup = BeautifulSoup(f.read(), 'html.parser')# 移除script和style标签for script in soup(["script", "style"]):script.decompose()return soup.get_text()except ImportError:print("⚠️ 需要安装BeautifulSoup4库来处理HTML文件: pip install beautifulsoup4")return ""def _extract_from_json(self, file_path: str) -> str:"""提取JSON文件内容"""with open(file_path, 'r', encoding='utf-8') as f:data = json.load(f)# 递归提取所有文本值def extract_text_values(obj):if isinstance(obj, str):return [obj]elif isinstance(obj, dict):texts = []for 
value in obj.values():texts.extend(extract_text_values(value))return textselif isinstance(obj, list):texts = []for item in obj:texts.extend(extract_text_values(item))return textselse:return [str(obj)]text_values = extract_text_values(data)return ' '.join(text_values)def clean_text(self, text: str) -> str:"""清洗文本内容"""if not text:return ""# 移除多余的空白字符text = re.sub(r'\s+', ' ', text)# 移除特殊字符(保留基本标点)text = re.sub(r'[^\w\s\.\,\!\?\;\:\-\(\)]', '', text)# 移除过短的行lines = text.split('\n')cleaned_lines = [line.strip() for line in lines if len(line.strip()) > 10]return '\n'.join(cleaned_lines).strip()# 文档分块策略class ChunkingStrategy:"""文档分块策略基类"""def chunk_document(self, text: str, metadata: Dict) -> List[DocumentChunk]:raise NotImplementedErrorclass FixedSizeChunking(ChunkingStrategy):"""固定大小分块策略"""def __init__(self, chunk_size: int = 1000, overlap: int = 200):self.chunk_size = chunk_sizeself.overlap = overlapdef chunk_document(self, text: str, metadata: Dict) -> List[DocumentChunk]:"""按固定大小分块"""chunks = []doc_id = metadata.get('doc_id', 'unknown')# 计算分块位置start = 0chunk_index = 0while start < len(text):end = min(start + self.chunk_size, len(text))# 尝试在单词边界处分割if end < len(text):# 向前查找最近的空格while end > start and text[end] != ' ':end -= 1if end == start: # 如果没找到空格,使用原始位置end = min(start + self.chunk_size, len(text))chunk_text = text[start:end].strip()if chunk_text: # 只添加非空分块chunk = DocumentChunk(id="", # 将在__post_init__中生成content=chunk_text,metadata={**metadata,'chunk_method': 'fixed_size','chunk_size': len(chunk_text),'start_pos': start,'end_pos': end},chunk_index=chunk_index,parent_doc_id=doc_id)chunks.append(chunk)chunk_index += 1# 计算下一个分块的起始位置(考虑重叠)start = max(start + self.chunk_size - self.overlap, end)return chunksclass SemanticChunking(ChunkingStrategy):"""语义分块策略"""def __init__(self, max_chunk_size: int = 1500):self.max_chunk_size = max_chunk_sizedef chunk_document(self, text: str, metadata: Dict) -> List[DocumentChunk]:"""按语义边界分块"""chunks = []doc_id = metadata.get('doc_id', 'unknown')# 按段落分割paragraphs = text.split('\n\n')current_chunk = ""chunk_index = 0for paragraph in paragraphs:paragraph = paragraph.strip()if not paragraph:continue# 如果当前段落加上现有分块超过最大长度,先保存当前分块if current_chunk and len(current_chunk) + len(paragraph) > self.max_chunk_size:chunk = DocumentChunk(id="",content=current_chunk.strip(),metadata={**metadata,'chunk_method': 'semantic','chunk_size': len(current_chunk),'paragraph_count': current_chunk.count('\n\n') + 1},chunk_index=chunk_index,parent_doc_id=doc_id)chunks.append(chunk)chunk_index += 1current_chunk = ""# 添加当前段落if current_chunk:current_chunk += "\n\n" + paragraphelse:current_chunk = paragraph# 添加最后一个分块if current_chunk.strip():chunk = DocumentChunk(id="",content=current_chunk.strip(),metadata={**metadata,'chunk_method': 'semantic','chunk_size': len(current_chunk),'paragraph_count': current_chunk.count('\n\n') + 1},chunk_index=chunk_index,parent_doc_id=doc_id)chunks.append(chunk)return chunks# 文档处理管道演示def demonstrate_document_processing():"""演示文档处理流程"""print("📄 文档处理与分块演示")print("=" * 50)# 创建文档处理器processor = DocumentProcessor()# 模拟文档内容sample_documents = {"tech_doc.md": """# RAG技术详解## 什么是RAG检索增强生成(RAG)是一种将信息检索与文本生成相结合的AI技术。它通过在生成过程中动态检索相关知识,显著提升了语言模型的准确性。## RAG的优势1. **实时知识更新**:无需重新训练模型即可更新知识库2. **减少幻觉**:基于真实文档生成答案3. **可追溯性**:提供答案来源,增强可信度## 技术架构RAG系统通常包含以下组件:- 文档处理器- 向量数据库- 检索器- 生成器这些组件协同工作,实现高质量的问答系统。""","user_manual.txt": """用户手册第一章:系统介绍本系统是一个基于RAG技术的智能问答平台。用户可以上传文档,系统会自动建立知识库,然后回答相关问题。第二章:使用方法1. 上传文档到系统2. 等待文档处理完成3. 在问答界面提出问题4. 
系统会基于文档内容给出答案第三章:注意事项- 支持多种文档格式- 文档内容应该准确可靠- 系统会保护用户隐私"""}# 处理每个文档all_chunks = []for filename, content in sample_documents.items():print(f"\n📝 处理文档:{filename}")print(f"原始长度:{len(content)} 字符")# 清洗文本cleaned_content = processor.clean_text(content)print(f"清洗后长度:{len(cleaned_content)} 字符")# 创建文档元数据doc_metadata = {'doc_id': filename.replace('.', '_'),'filename': filename,'original_length': len(content),'cleaned_length': len(cleaned_content),'processing_time': datetime.now().isoformat()}# 测试不同的分块策略print("\n🔪 分块策略对比:")# 固定大小分块fixed_chunker = FixedSizeChunking(chunk_size=300, overlap=50)fixed_chunks = fixed_chunker.chunk_document(cleaned_content, doc_metadata)print(f" 固定大小分块:{len(fixed_chunks)} 个分块")# 语义分块semantic_chunker = SemanticChunking(max_chunk_size=400)semantic_chunks = semantic_chunker.chunk_document(cleaned_content, doc_metadata)print(f" 语义分块:{len(semantic_chunks)} 个分块")# 显示分块详情print("\n📊 分块详情(语义分块):")for i, chunk in enumerate(semantic_chunks[:3]): # 只显示前3个print(f" 分块 {i+1}:")print(f" ID: {chunk.id}")print(f" 长度: {len(chunk.content)} 字符")print(f" 内容预览: {chunk.content[:100]}...")print(f" 元数据: {chunk.metadata}")all_chunks.extend(semantic_chunks)# 更新处理统计processor.processing_stats["total_docs"] += 1processor.processing_stats["successful_docs"] += 1processor.processing_stats["total_chunks"] += len(semantic_chunks)# 显示处理统计print(f"\n📈 处理统计:")print(f" 总文档数: {processor.processing_stats['total_docs']}")print(f" 成功处理: {processor.processing_stats['successful_docs']}")print(f" 总分块数: {processor.processing_stats['total_chunks']}")print(f" 平均每文档分块数: {processor.processing_stats['total_chunks'] / processor.processing_stats['total_docs']:.1f}")return all_chunks# 运行文档处理演示processed_chunks = demonstrate_document_processing()print("\n✅ 文档处理演示完成")
🧮 向量化技术深入
文档分块完成后,下一步是将文本转换为向量表示。这是RAG系统的核心技术之一:
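在进入完整实现之前,先看一个调用真实嵌入模型的最小示例(假设已安装 sentence-transformers 库,模型沿用上文提到的 all-MiniLM-L6-v2;后文的完整实现为便于离线演示改用了模拟向量):

```python
# 使用sentence-transformers生成真实语义向量的最小示例
# 前提:pip install sentence-transformers,首次运行需联网下载模型
from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # 输出384维向量

texts = [
    "RAG技术结合了检索与生成",
    "向量数据库用于存储高维语义向量",
]
embeddings = model.encode(texts, normalize_embeddings=True)  # shape: (2, 384)

# 归一化后,点积即余弦相似度
similarity = float(np.dot(embeddings[0], embeddings[1]))
print(f"向量维度: {embeddings.shape[1]}, 两句相似度: {similarity:.3f}")
```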
# 向量化技术实现from abc import ABC, abstractmethodimport numpy as npfrom typing import List, Dict, Optionalfrom sklearn.feature_extraction.text import TfidfVectorizerfrom sklearn.metrics.pairwise import cosine_similarityclass EmbeddingModel(ABC):"""嵌入模型抽象基类"""@abstractmethoddef encode(self, texts: List[str]) -> np.ndarray:"""将文本编码为向量"""pass@abstractmethoddef get_dimension(self) -> int:"""获取向量维度"""passclass TFIDFEmbedding(EmbeddingModel):"""基于TF-IDF的嵌入模型"""def __init__(self, max_features: int = 10000):self.vectorizer = TfidfVectorizer(max_features=max_features,stop_words='english',ngram_range=(1, 2))self.is_fitted = Falseself.dimension = max_featuresdef fit(self, texts: List[str]):"""训练TF-IDF模型"""self.vectorizer.fit(texts)self.is_fitted = True# 更新实际维度self.dimension = len(self.vectorizer.vocabulary_)def encode(self, texts: List[str]) -> np.ndarray:"""编码文本为TF-IDF向量"""if not self.is_fitted:self.fit(texts)vectors = self.vectorizer.transform(texts)return vectors.toarray()def get_dimension(self) -> int:return self.dimensionclass SimulatedTransformerEmbedding(EmbeddingModel):"""模拟Transformer嵌入模型(如BERT、Sentence-BERT等)"""def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", dimension: int = 384):self.model_name = model_nameself.dimension = dimensionprint(f"🤖 初始化模拟嵌入模型: {model_name} (维度: {dimension})")def encode(self, texts: List[str]) -> np.ndarray:"""编码文本为密集向量(模拟)"""# 在实际应用中,这里会调用真实的Transformer模型# 例如使用sentence-transformers库embeddings = []for text in texts:# 模拟基于文本内容的向量生成# 实际实现会使用预训练的Transformer模型np.random.seed(hash(text) % 2**32) # 基于文本内容的确定性随机embedding = np.random.normal(0, 1, self.dimension)# 归一化向量embedding = embedding / np.linalg.norm(embedding)embeddings.append(embedding)return np.array(embeddings)def get_dimension(self) -> int:return self.dimensionclass VectorDatabase:"""向量数据库实现"""def __init__(self, embedding_model: EmbeddingModel):self.embedding_model = embedding_modelself.vectors: np.ndarray = Noneself.metadata: List[Dict] = []self.index_to_chunk_id: Dict[int, str] = {}self.chunk_id_to_index: Dict[str, int] = {}def add_documents(self, chunks: List[DocumentChunk]):"""添加文档分块到向量数据库"""print(f"\n💾 向量数据库添加文档分块")print(f"待添加分块数: {len(chunks)}")# 提取文本内容texts = [chunk.content for chunk in chunks]# 生成向量print("🧮 生成向量嵌入...")new_vectors = self.embedding_model.encode(texts)# 存储向量和元数据if self.vectors is None:self.vectors = new_vectorselse:self.vectors = np.vstack([self.vectors, new_vectors])# 更新索引映射start_index = len(self.metadata)for i, chunk in enumerate(chunks):index = start_index + iself.index_to_chunk_id[index] = chunk.idself.chunk_id_to_index[chunk.id] = index# 存储元数据chunk_metadata = {'chunk_id': chunk.id,'content': chunk.content,'parent_doc_id': chunk.parent_doc_id,'chunk_index': chunk.chunk_index,'metadata': chunk.metadata,'vector_index': index}self.metadata.append(chunk_metadata)print(f"✅ 已添加 {len(chunks)} 个向量")print(f"📊 数据库统计:")print(f" 总向量数: {len(self.vectors)}")print(f" 向量维度: {self.vectors.shape[1]}")print(f" 存储大小: {self.vectors.nbytes / 1024 / 1024:.2f} MB")def search(self, query: str, top_k: int = 5) -> List[Dict]:"""搜索相似向量"""if self.vectors is None or len(self.vectors) == 0:return []# 将查询转换为向量query_vector = self.embedding_model.encode([query])[0]# 计算相似度similarities = cosine_similarity([query_vector], self.vectors)[0]# 获取top_k结果top_indices = np.argsort(similarities)[::-1][:top_k]results = []for i, index in enumerate(top_indices):result = {'rank': i + 1,'chunk_id': self.index_to_chunk_id[index],'similarity': float(similarities[index]),'content': 
self.metadata[index]['content'],'metadata': self.metadata[index]['metadata'],'parent_doc_id': self.metadata[index]['parent_doc_id']}results.append(result)return resultsdef get_statistics(self) -> Dict:"""获取数据库统计信息"""if self.vectors is None:return {"total_vectors": 0, "dimension": 0, "storage_mb": 0}return {"total_vectors": len(self.vectors),"dimension": self.vectors.shape[1],"storage_mb": self.vectors.nbytes / 1024 / 1024,"total_chunks": len(self.metadata),"unique_documents": len(set(meta['parent_doc_id'] for meta in self.metadata))}# 向量化演示def demonstrate_vectorization():"""演示向量化过程"""print("🧮 向量化技术演示")print("=" * 50)# 使用之前处理的文档分块chunks = processed_chunks# 测试不同的嵌入模型embedding_models = {"TF-IDF": TFIDFEmbedding(max_features=1000),"Simulated-BERT": SimulatedTransformerEmbedding("sentence-bert", 384)}for model_name, embedding_model in embedding_models.items():print(f"\n🤖 测试嵌入模型: {model_name}")print(f"向量维度: {embedding_model.get_dimension()}")# 创建向量数据库vector_db = VectorDatabase(embedding_model)# 添加文档vector_db.add_documents(chunks)# 显示统计信息stats = vector_db.get_statistics()print(f"📊 数据库统计: {stats}")# 测试搜索test_queries = ["什么是RAG技术?","如何使用系统?","文档处理方法"]for query in test_queries:print(f"\n🔍 搜索查询: {query}")results = vector_db.search(query, top_k=3)for result in results:print(f" 排名 {result['rank']}: 相似度 {result['similarity']:.3f}")print(f" 文档: {result['parent_doc_id']}")print(f" 内容: {result['content'][:100]}...")print(f"\n✅ {model_name} 模型测试完成")# 运行向量化演示demonstrate_vectorization()print("\n🎉 向量化技术演示完成")
🎯 向量化质量评估
为了确保向量化的质量,我们需要建立评估机制:
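一个简单可行的思路是:构造若干"相关/不相关"文本对,检查嵌入模型是否给相关对打出明显更高的余弦相似度。下面是一个示意性实现,其中 evaluate_embedding_quality 为本节假设的评估函数,encode 参数沿用上文嵌入模型"文本列表进、向量列表出"的接口约定:

```python
# 向量化质量评估示意:相关文本对的相似度应显著高于不相关文本对
import numpy as np

def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def evaluate_embedding_quality(encode, positive_pairs, negative_pairs):
    """encode: 接收文本列表、返回向量列表的函数(与上文嵌入模型接口一致)"""
    pos_scores = [cosine_sim(*encode([a, b])) for a, b in positive_pairs]
    neg_scores = [cosine_sim(*encode([a, b])) for a, b in negative_pairs]
    return {
        "正样本平均相似度": float(np.mean(pos_scores)),
        "负样本平均相似度": float(np.mean(neg_scores)),
        "区分度(越大越好)": float(np.mean(pos_scores) - np.mean(neg_scores)),
    }

# 用法示例:此处用随机向量代替真实模型,仅演示评估流程
fake_encode = lambda texts: [np.random.rand(384) for _ in texts]
report = evaluate_embedding_quality(
    fake_encode,
    positive_pairs=[("什么是RAG", "RAG技术介绍")],
    negative_pairs=[("什么是RAG", "今天天气如何")],
)
print(report)
```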
通过本节的学习,我们深入了解了RAG系统中文档处理和向量化的核心技术。这些技术为高质量的信息检索奠定了坚实基础。
本节我们学习了文档处理、分块策略和向量化技术。下一节我们将探讨向量数据库技术,了解如何高效存储和检索大规模向量数据。
29.3 向量数据库技术
🏛️ 语义存储仓库
在我们的知识检索中心中,语义存储仓库是整个系统的核心基础设施。就像一座高科技的立体仓库,它不仅要存储海量的向量数据,还要支持高速的相似度检索和实时的数据更新。
🗄️ 向量数据库核心概念
向量数据库是专门为存储和检索高维向量数据而设计的数据库系统。与传统的关系型数据库不同,它优化了向量相似度计算和近似最近邻搜索。
# 向量数据库核心技术实现import numpy as npimport jsonimport pickleimport sqlite3from typing import List, Dict, Any, Optional, Tuplefrom dataclasses import dataclass, asdictfrom datetime import datetimeimport threadingimport timefrom abc import ABC, abstractmethodfrom enum import Enumclass IndexType(Enum):"""索引类型枚举"""FLAT = "flat" # 暴力搜索IVF = "ivf" # 倒排文件索引HNSW = "hnsw" # 层次化小世界图LSH = "lsh" # 局部敏感哈希ANNOY = "annoy" # Annoy树索引class DistanceMetric(Enum):"""距离度量枚举"""COSINE = "cosine" # 余弦相似度EUCLIDEAN = "euclidean" # 欧几里得距离DOT_PRODUCT = "dot_product" # 点积MANHATTAN = "manhattan" # 曼哈顿距离@dataclassclass VectorRecord:"""向量记录数据结构"""id: strvector: np.ndarraymetadata: Dict[str, Any]timestamp: datetimedef to_dict(self) -> Dict:"""转换为字典格式"""return {'id': self.id,'vector': self.vector.tolist(),'metadata': self.metadata,'timestamp': self.timestamp.isoformat()}@classmethoddef from_dict(cls, data: Dict) -> 'VectorRecord':"""从字典创建记录"""return cls(id=data['id'],vector=np.array(data['vector']),metadata=data['metadata'],timestamp=datetime.fromisoformat(data['timestamp']))class VectorIndex(ABC):"""向量索引抽象基类"""def __init__(self, dimension: int, metric: DistanceMetric = DistanceMetric.COSINE):self.dimension = dimensionself.metric = metricself.is_trained = False@abstractmethoddef add_vectors(self, vectors: np.ndarray, ids: List[str]):"""添加向量到索引"""pass@abstractmethoddef search(self, query_vector: np.ndarray, k: int) -> Tuple[List[str], List[float]]:"""搜索最相似的k个向量"""pass@abstractmethoddef remove_vector(self, vector_id: str) -> bool:"""从索引中移除向量"""passclass FlatIndex(VectorIndex):"""暴力搜索索引实现"""def __init__(self, dimension: int, metric: DistanceMetric = DistanceMetric.COSINE):super().__init__(dimension, metric)self.vectors: np.ndarray = Noneself.ids: List[str] = []self.id_to_index: Dict[str, int] = {}def add_vectors(self, vectors: np.ndarray, ids: List[str]):"""添加向量到索引"""if vectors.shape[1] != self.dimension:raise ValueError(f"向量维度不匹配: 期望 {self.dimension}, 实际 {vectors.shape[1]}")if self.vectors is None:self.vectors = vectors.copy()else:self.vectors = np.vstack([self.vectors, vectors])# 更新ID映射start_index = len(self.ids)for i, vector_id in enumerate(ids):self.id_to_index[vector_id] = start_index + iself.ids.extend(ids)self.is_trained = Truedef search(self, query_vector: np.ndarray, k: int) -> Tuple[List[str], List[float]]:"""搜索最相似的k个向量"""if not self.is_trained or self.vectors is None:return [], []# 计算相似度if self.metric == DistanceMetric.COSINE:# 余弦相似度query_norm = query_vector / np.linalg.norm(query_vector)vectors_norm = self.vectors / np.linalg.norm(self.vectors, axis=1, keepdims=True)similarities = np.dot(vectors_norm, query_norm)# 转换为距离(距离越小越相似)distances = 1 - similaritieselif self.metric == DistanceMetric.EUCLIDEAN:# 欧几里得距离distances = np.linalg.norm(self.vectors - query_vector, axis=1)else:raise NotImplementedError(f"距离度量 {self.metric} 暂未实现")# 获取top-k结果k = min(k, len(self.ids))top_indices = np.argpartition(distances, k)[:k]top_indices = top_indices[np.argsort(distances[top_indices])]result_ids = [self.ids[i] for i in top_indices]result_distances = distances[top_indices].tolist()return result_ids, result_distancesdef remove_vector(self, vector_id: str) -> bool:"""从索引中移除向量"""if vector_id not in self.id_to_index:return Falseindex = self.id_to_index[vector_id]# 删除向量self.vectors = np.delete(self.vectors, index, axis=0)# 更新ID列表和映射del self.ids[index]del self.id_to_index[vector_id]# 重新构建索引映射self.id_to_index = {id_: i for i, id_ in enumerate(self.ids)}return Trueclass HNSWIndex(VectorIndex):"""HNSW索引实现(简化版)"""def __init__(self, dimension: int, 
metric: DistanceMetric = DistanceMetric.COSINE,max_connections: int = 16, ef_construction: int = 200):super().__init__(dimension, metric)self.max_connections = max_connectionsself.ef_construction = ef_constructionself.vectors: Dict[str, np.ndarray] = {}self.graph: Dict[str, List[str]] = {}self.entry_point: Optional[str] = Nonedef add_vectors(self, vectors: np.ndarray, ids: List[str]):"""添加向量到HNSW图"""for vector, vector_id in zip(vectors, ids):self._add_single_vector(vector, vector_id)self.is_trained = Truedef _add_single_vector(self, vector: np.ndarray, vector_id: str):"""添加单个向量到图中"""self.vectors[vector_id] = vectorself.graph[vector_id] = []if self.entry_point is None:self.entry_point = vector_idreturn# 简化的HNSW插入逻辑# 在实际实现中,这里会有更复杂的层次结构candidates = self._search_layer(vector, self.ef_construction)# 连接到最近的邻居connections = min(len(candidates), self.max_connections)for i in range(connections):neighbor_id = candidates[i][1]# 双向连接if neighbor_id not in self.graph[vector_id]:self.graph[vector_id].append(neighbor_id)if vector_id not in self.graph[neighbor_id]:self.graph[neighbor_id].append(vector_id)# 修剪连接(保持度数限制)if len(self.graph[neighbor_id]) > self.max_connections:self._prune_connections(neighbor_id)def _search_layer(self, query_vector: np.ndarray, ef: int) -> List[Tuple[float, str]]:"""在图层中搜索"""if not self.vectors:return []visited = set()candidates = []# 从入口点开始if self.entry_point:dist = self._calculate_distance(query_vector, self.vectors[self.entry_point])candidates.append((dist, self.entry_point))visited.add(self.entry_point)# 贪心搜索for _ in range(ef):if not candidates:breakcandidates.sort()current_dist, current_id = candidates.pop(0)# 检查邻居for neighbor_id in self.graph.get(current_id, []):if neighbor_id not in visited:visited.add(neighbor_id)dist = self._calculate_distance(query_vector, self.vectors[neighbor_id])candidates.append((dist, neighbor_id))candidates.sort()return candidatesdef _calculate_distance(self, v1: np.ndarray, v2: np.ndarray) -> float:"""计算两个向量之间的距离"""if self.metric == DistanceMetric.COSINE:return 1 - np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))elif self.metric == DistanceMetric.EUCLIDEAN:return np.linalg.norm(v1 - v2)else:raise NotImplementedError(f"距离度量 {self.metric} 暂未实现")def _prune_connections(self, vector_id: str):"""修剪连接以保持度数限制"""if len(self.graph[vector_id]) <= self.max_connections:return# 简化的修剪策略:保留距离最近的连接vector = self.vectors[vector_id]connections = self.graph[vector_id]# 计算到所有邻居的距离distances = []for neighbor_id in connections:dist = self._calculate_distance(vector, self.vectors[neighbor_id])distances.append((dist, neighbor_id))# 保留最近的邻居distances.sort()new_connections = [neighbor_id for _, neighbor_id in distances[:self.max_connections]]self.graph[vector_id] = new_connectionsdef search(self, query_vector: np.ndarray, k: int) -> Tuple[List[str], List[float]]:"""搜索最相似的k个向量"""if not self.is_trained:return [], []candidates = self._search_layer(query_vector, max(self.ef_construction, k))# 返回top-k结果k = min(k, len(candidates))top_candidates = candidates[:k]result_ids = [candidate[1] for candidate in top_candidates]result_distances = [candidate[0] for candidate in top_candidates]return result_ids, result_distancesdef remove_vector(self, vector_id: str) -> bool:"""从索引中移除向量"""if vector_id not in self.vectors:return False# 移除所有连接for neighbor_id in self.graph.get(vector_id, []):if vector_id in self.graph[neighbor_id]:self.graph[neighbor_id].remove(vector_id)# 删除向量和图节点del self.vectors[vector_id]del self.graph[vector_id]# 更新入口点if self.entry_point == 
vector_id:self.entry_point = next(iter(self.vectors.keys())) if self.vectors else Nonereturn Trueclass AdvancedVectorDatabase:"""高级向量数据库实现"""def __init__(self, dimension: int, index_type: IndexType = IndexType.HNSW,metric: DistanceMetric = DistanceMetric.COSINE,persist_path: Optional[str] = None):self.dimension = dimensionself.index_type = index_typeself.metric = metricself.persist_path = persist_path# 创建索引if index_type == IndexType.FLAT:self.index = FlatIndex(dimension, metric)elif index_type == IndexType.HNSW:self.index = HNSWIndex(dimension, metric)else:raise NotImplementedError(f"索引类型 {index_type} 暂未实现")# 元数据存储self.metadata_store: Dict[str, Dict] = {}# 统计信息self.stats = {"total_vectors": 0,"total_searches": 0,"total_inserts": 0,"total_deletes": 0,"avg_search_time": 0.0}# 线程锁self._lock = threading.RLock()# 如果指定了持久化路径,尝试加载if persist_path:self.load_from_disk()def insert_vectors(self, records: List[VectorRecord]) -> bool:"""插入向量记录"""with self._lock:try:vectors = np.array([record.vector for record in records])ids = [record.id for record in records]# 添加到索引self.index.add_vectors(vectors, ids)# 存储元数据for record in records:self.metadata_store[record.id] = {'metadata': record.metadata,'timestamp': record.timestamp.isoformat()}# 更新统计self.stats["total_vectors"] += len(records)self.stats["total_inserts"] += len(records)print(f"✅ 成功插入 {len(records)} 个向量")return Trueexcept Exception as e:print(f"❌ 插入向量时出错: {str(e)}")return Falsedef search_vectors(self, query_vector: np.ndarray, k: int = 10,filter_metadata: Optional[Dict] = None) -> List[Dict]:"""搜索向量"""with self._lock:start_time = time.time()try:# 执行向量搜索result_ids, distances = self.index.search(query_vector, k * 2) # 获取更多结果用于过滤# 构建结果results = []for vector_id, distance in zip(result_ids, distances):if vector_id in self.metadata_store:metadata = self.metadata_store[vector_id]['metadata']# 应用元数据过滤if filter_metadata:if not self._match_filter(metadata, filter_metadata):continueresult = {'id': vector_id,'distance': distance,'similarity': 1 - distance if self.metric == DistanceMetric.COSINE else distance,'metadata': metadata,'timestamp': self.metadata_store[vector_id]['timestamp']}results.append(result)if len(results) >= k:break# 更新统计search_time = time.time() - start_timeself.stats["total_searches"] += 1self.stats["avg_search_time"] = ((self.stats["avg_search_time"] * (self.stats["total_searches"] - 1) + search_time)/ self.stats["total_searches"])return resultsexcept Exception as e:print(f"❌ 搜索向量时出错: {str(e)}")return []def _match_filter(self, metadata: Dict, filter_metadata: Dict) -> bool:"""检查元数据是否匹配过滤条件"""for key, value in filter_metadata.items():if key not in metadata:return Falseif isinstance(value, list):if metadata[key] not in value:return Falseelif metadata[key] != value:return Falsereturn Truedef delete_vector(self, vector_id: str) -> bool:"""删除向量"""with self._lock:try:# 从索引中删除if self.index.remove_vector(vector_id):# 删除元数据if vector_id in self.metadata_store:del self.metadata_store[vector_id]# 更新统计self.stats["total_vectors"] -= 1self.stats["total_deletes"] += 1print(f"✅ 成功删除向量: {vector_id}")return Trueelse:print(f"⚠️ 向量不存在: {vector_id}")return Falseexcept Exception as e:print(f"❌ 删除向量时出错: {str(e)}")return Falsedef get_statistics(self) -> Dict:"""获取数据库统计信息"""with self._lock:return {**self.stats,"index_type": self.index_type.value,"metric": self.metric.value,"dimension": self.dimension}def save_to_disk(self) -> bool:"""保存到磁盘"""if not self.persist_path:return Falsewith self._lock:try:# 准备保存数据save_data = {'dimension': self.dimension,'index_type': 
self.index_type.value,'metric': self.metric.value,'metadata_store': self.metadata_store,'stats': self.stats,'vectors': {},'index_data': {}}# 保存向量数据if hasattr(self.index, 'vectors') and self.index.vectors is not None:if isinstance(self.index.vectors, np.ndarray):save_data['vectors'] = {'data': self.index.vectors.tolist(),'ids': self.index.ids}elif isinstance(self.index.vectors, dict):save_data['vectors'] = {id_: vector.tolist() for id_, vector in self.index.vectors.items()}# 保存索引特定数据if hasattr(self.index, 'graph'):save_data['index_data']['graph'] = self.index.graphsave_data['index_data']['entry_point'] = self.index.entry_point# 写入文件with open(self.persist_path, 'w', encoding='utf-8') as f:json.dump(save_data, f, indent=2, ensure_ascii=False)print(f"✅ 数据库已保存到: {self.persist_path}")return Trueexcept Exception as e:print(f"❌ 保存数据库时出错: {str(e)}")return Falsedef load_from_disk(self) -> bool:"""从磁盘加载"""if not self.persist_path:return Falsetry:with open(self.persist_path, 'r', encoding='utf-8') as f:data = json.load(f)# 恢复元数据self.metadata_store = data.get('metadata_store', {})self.stats = data.get('stats', self.stats)# 恢复向量数据vectors_data = data.get('vectors', {})if vectors_data:if 'data' in vectors_data and 'ids' in vectors_data:# Flat索引格式vectors = np.array(vectors_data['data'])ids = vectors_data['ids']if len(vectors) > 0:self.index.add_vectors(vectors, ids)else:# HNSW索引格式for vector_id, vector_data in vectors_data.items():if isinstance(vector_data, list):vector = np.array(vector_data)self.index.add_vectors(vector.reshape(1, -1), [vector_id])# 恢复索引特定数据index_data = data.get('index_data', {})if hasattr(self.index, 'graph') and 'graph' in index_data:self.index.graph = index_data['graph']self.index.entry_point = index_data.get('entry_point')print(f"✅ 数据库已从磁盘加载: {self.persist_path}")return Trueexcept FileNotFoundError:print(f"⚠️ 持久化文件不存在,将创建新数据库: {self.persist_path}")return Falseexcept Exception as e:print(f"❌ 加载数据库时出错: {str(e)}")return False# 向量数据库演示def demonstrate_vector_database():"""演示向量数据库功能"""print("🗄️ 向量数据库技术演示")print("=" * 60)# 创建向量数据库实例databases = {"Flat索引": AdvancedVectorDatabase(dimension=384,index_type=IndexType.FLAT,persist_path="vector_db_flat.json"),"HNSW索引": AdvancedVectorDatabase(dimension=384,index_type=IndexType.HNSW,persist_path="vector_db_hnsw.json")}# 准备测试数据test_records = []categories = ["技术文档", "用户手册", "API文档", "教程", "FAQ"]for i in range(50):# 生成模拟向量np.random.seed(i)vector = np.random.normal(0, 1, 384)vector = vector / np.linalg.norm(vector) # 归一化record = VectorRecord(id=f"doc_{i:03d}",vector=vector,metadata={"title": f"文档_{i:03d}","category": categories[i % len(categories)],"length": np.random.randint(100, 2000),"author": f"作者_{i % 5}","tags": [f"tag_{j}" for j in range(i % 3 + 1)]},timestamp=datetime.now())test_records.append(record)# 测试每个数据库for db_name, db in databases.items():print(f"\n🔍 测试 {db_name}")print("-" * 40)# 插入数据print("📥 插入测试数据...")start_time = time.time()success = db.insert_vectors(test_records)insert_time = time.time() - start_timeif success:print(f"✅ 插入完成,耗时: {insert_time:.3f}秒")# 显示统计信息stats = db.get_statistics()print(f"📊 数据库统计: {stats}")# 测试搜索print("\n🔍 测试向量搜索...")# 创建查询向量np.random.seed(100)query_vector = np.random.normal(0, 1, 384)query_vector = query_vector / np.linalg.norm(query_vector)# 执行搜索start_time = time.time()results = db.search_vectors(query_vector, k=5)search_time = time.time() - start_timeprint(f"⏱️ 搜索耗时: {search_time:.3f}秒")print(f"📋 搜索结果:")for i, result in enumerate(results):print(f" {i+1}. 
ID: {result['id']}")print(f" 相似度: {result['similarity']:.4f}")print(f" 标题: {result['metadata']['title']}")print(f" 类别: {result['metadata']['category']}")# 测试带过滤的搜索print("\n🎯 测试元数据过滤搜索...")filtered_results = db.search_vectors(query_vector,k=3,filter_metadata={"category": "技术文档"})print(f"📋 过滤结果 (只显示技术文档):")for i, result in enumerate(filtered_results):print(f" {i+1}. {result['metadata']['title']} - {result['metadata']['category']}")# 测试删除print("\n🗑️ 测试向量删除...")delete_success = db.delete_vector("doc_000")if delete_success:print("✅ 删除成功")# 验证删除verify_results = db.search_vectors(test_records[0].vector, k=5)found_deleted = any(r['id'] == 'doc_000' for r in verify_results)print(f"🔍 删除验证: {'❌ 仍然存在' if found_deleted else '✅ 已删除'}")# 保存到磁盘print("\n💾 测试持久化...")save_success = db.save_to_disk()if save_success:print("✅ 保存成功")print(f"\n🎉 向量数据库演示完成!")# 运行向量数据库演示demonstrate_vector_database()print("\n✅ 向量数据库技术演示完成")
🏗️ 向量数据库架构对比
不同的向量数据库在索引结构上各有侧重,常见索引类型的优势和适用场景如下:
- **Flat(暴力搜索)**:精确检索,但耗时随数据量线性增长,适合万级以下的小规模知识库
- **IVF(倒排文件索引)**:先聚类分桶再搜索,在精度与速度之间折中,适合中等规模数据
- **HNSW(层次化小世界图)**:基于图的近似搜索,查询快、召回高,但内存占用较大,适合大规模在线检索
- **LSH(局部敏感哈希)**:构建成本低,精度相对一般,适合对召回要求不太严格的场景
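作为对照,下面给出一个基于 FAISS 库的最小示例(假设已安装 faiss-cpu;它与上文的自实现索引无直接关系,仅用于说明 Flat 与 HNSW 两类索引在实际库中的基本用法):

```python
# 使用FAISS对比Flat与HNSW索引的最小示例
# 前提:pip install faiss-cpu;FAISS要求向量为float32
import numpy as np
import faiss

d = 384
xb = np.random.rand(10000, d).astype("float32")  # 模拟库内向量
xq = np.random.rand(5, d).astype("float32")      # 模拟查询向量

# Flat索引:精确搜索,适合小规模
flat_index = faiss.IndexFlatL2(d)
flat_index.add(xb)

# HNSW索引:近似搜索,适合大规模(32为每个节点的连接数)
hnsw_index = faiss.IndexHNSWFlat(d, 32)
hnsw_index.add(xb)

k = 5
flat_dist, flat_ids = flat_index.search(xq, k)
hnsw_dist, hnsw_ids = hnsw_index.search(xq, k)
print("Flat  top-5:", flat_ids[0])
print("HNSW  top-5:", hnsw_ids[0])
```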
🔧 向量数据库优化策略
为了提升向量数据库的性能,我们需要考虑多个维度的优化:
class VectorDatabaseOptimizer:"""向量数据库优化器"""def __init__(self, database: AdvancedVectorDatabase):self.database = databaseself.optimization_history = []def analyze_performance(self) -> Dict[str, Any]:"""分析数据库性能"""stats = self.database.get_statistics()analysis = {"performance_score": 0,"bottlenecks": [],"recommendations": []}# 分析搜索性能avg_search_time = stats.get("avg_search_time", 0)if avg_search_time > 0.1: # 100msanalysis["bottlenecks"].append("搜索延迟过高")analysis["recommendations"].append("考虑使用HNSW索引或增加索引参数")# 分析向量规模total_vectors = stats.get("total_vectors", 0)if total_vectors > 100000 and self.database.index_type == IndexType.FLAT:analysis["bottlenecks"].append("大规模数据使用暴力搜索")analysis["recommendations"].append("升级到近似最近邻索引(HNSW/IVF)")# 分析内存使用estimated_memory = total_vectors * self.database.dimension * 4 / (1024**2) # MBif estimated_memory > 1000: # 1GBanalysis["bottlenecks"].append("内存使用过高")analysis["recommendations"].append("考虑量化压缩或分布式存储")# 计算综合性能评分base_score = 100if avg_search_time > 0.05:base_score -= 20if total_vectors > 50000 and self.database.index_type == IndexType.FLAT:base_score -= 30if estimated_memory > 500:base_score -= 15analysis["performance_score"] = max(0, base_score)return analysisdef optimize_index_parameters(self) -> Dict[str, Any]:"""优化索引参数"""if self.database.index_type == IndexType.HNSW:return self._optimize_hnsw_parameters()elif self.database.index_type == IndexType.FLAT:return self._suggest_index_upgrade()else:return {"message": "当前索引类型暂不支持参数优化"}def _optimize_hnsw_parameters(self) -> Dict[str, Any]:"""优化HNSW参数"""stats = self.database.get_statistics()total_vectors = stats.get("total_vectors", 0)recommendations = {"current_params": {"max_connections": getattr(self.database.index, 'max_connections', 16),"ef_construction": getattr(self.database.index, 'ef_construction', 200)},"recommended_params": {},"reasoning": []}# 根据数据规模调整参数if total_vectors < 10000:recommendations["recommended_params"] = {"max_connections": 16,"ef_construction": 200}recommendations["reasoning"].append("小规模数据,使用默认参数即可")elif total_vectors < 100000:recommendations["recommended_params"] = {"max_connections": 32,"ef_construction": 400}recommendations["reasoning"].append("中等规模数据,增加连接数和构建参数")else:recommendations["recommended_params"] = {"max_connections": 64,"ef_construction": 800}recommendations["reasoning"].append("大规模数据,使用高性能参数")return recommendationsdef _suggest_index_upgrade(self) -> Dict[str, Any]:"""建议索引升级"""stats = self.database.get_statistics()total_vectors = stats.get("total_vectors", 0)if total_vectors > 10000:return {"suggestion": "升级到HNSW索引","reason": f"当前 {total_vectors} 个向量,HNSW索引可显著提升搜索速度","expected_improvement": "搜索速度提升10-100倍"}else:return {"suggestion": "保持当前索引","reason": "数据规模较小,暴力搜索已足够"}def benchmark_search_performance(self, num_queries: int = 100) -> Dict[str, float]:"""基准测试搜索性能"""print(f"🏃 开始性能基准测试 ({num_queries} 次查询)")# 生成随机查询向量query_vectors = []for i in range(num_queries):np.random.seed(i + 1000)vector = np.random.normal(0, 1, self.database.dimension)vector = vector / np.linalg.norm(vector)query_vectors.append(vector)# 测试不同k值的性能k_values = [1, 5, 10, 20]results = {}for k in k_values:times = []for query_vector in query_vectors:start_time = time.time()self.database.search_vectors(query_vector, k=k)search_time = time.time() - start_timetimes.append(search_time)avg_time = np.mean(times)std_time = np.std(times)results[f"k={k}"] = {"avg_time": avg_time,"std_time": std_time,"qps": 1.0 / avg_time if avg_time > 0 else 0}print(f" k={k}: 平均 {avg_time*1000:.2f}ms, QPS: {1.0/avg_time:.1f}")return 
results# 性能优化演示def demonstrate_optimization():"""演示性能优化"""print("⚡ 向量数据库性能优化演示")print("=" * 50)# 创建测试数据库db = AdvancedVectorDatabase(dimension=256,index_type=IndexType.HNSW)# 插入测试数据print("📥 插入测试数据...")test_records = []for i in range(1000):np.random.seed(i)vector = np.random.normal(0, 1, 256)vector = vector / np.linalg.norm(vector)record = VectorRecord(id=f"test_{i:04d}",vector=vector,metadata={"category": f"cat_{i%10}", "value": i},timestamp=datetime.now())test_records.append(record)db.insert_vectors(test_records)# 创建优化器optimizer = VectorDatabaseOptimizer(db)# 性能分析print("\n📊 性能分析:")analysis = optimizer.analyze_performance()print(f"性能评分: {analysis['performance_score']}/100")if analysis['bottlenecks']:print("⚠️ 发现的瓶颈:")for bottleneck in analysis['bottlenecks']:print(f" - {bottleneck}")if analysis['recommendations']:print("💡 优化建议:")for recommendation in analysis['recommendations']:print(f" - {recommendation}")# 参数优化建议print("\n🔧 索引参数优化:")param_optimization = optimizer.optimize_index_parameters()if 'recommended_params' in param_optimization:print(f"当前参数: {param_optimization['current_params']}")print(f"推荐参数: {param_optimization['recommended_params']}")for reason in param_optimization['reasoning']:print(f" - {reason}")# 性能基准测试print("\n🏃 性能基准测试:")benchmark_results = optimizer.benchmark_search_performance(50)print("✅ 优化演示完成")# 运行优化演示demonstrate_optimization()print("\n🎉 向量数据库技术完整演示结束")
通过本节的学习,我们深入了解了向量数据库的核心技术,包括不同索引类型的实现、性能优化策略和实际应用场景。这为构建高性能的RAG系统奠定了坚实的技术基础。
本节我们学习了向量数据库的核心技术和优化策略。下一节我们将探讨检索策略优化,了解如何提升检索精度和效率。
29.4 检索策略优化
🎯 智能检索调度中心
在我们的知识检索中心中,智能检索调度中心负责根据查询特征和业务需求,动态选择最优的检索策略。就像一个经验丰富的图书管理员,它知道如何快速找到最相关的信息。
🔍 检索策略核心概念
检索策略优化是RAG系统性能的关键因素,包括查询理解、检索方法选择、结果排序和后处理等多个环节。
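在进入完整实现之前,先看混合检索中常用的一种分数融合方法:倒数排名融合(Reciprocal Rank Fusion, RRF)。下面是一个示意性实现,dense_ranking 与 sparse_ranking 代表语义检索和关键词检索各自返回的文档ID排序,均为假设数据:

```python
# 倒数排名融合(RRF)示意:融合多路检索的排序结果
from collections import defaultdict

def rrf_fuse(rankings, k: int = 60, top_n: int = 5):
    """rankings: 多个排序列表(按相关性降序);k为平滑常数,常取60"""
    scores = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)  # 排名越靠前,贡献越大
    fused = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    return fused[:top_n]

dense_ranking = ["doc_003", "doc_001", "doc_005"]   # 语义检索结果(假设)
sparse_ranking = ["doc_001", "doc_004", "doc_003"]  # 关键词检索结果(假设)
print(rrf_fuse([dense_ranking, sparse_ranking]))
```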
# 检索策略优化实现import numpy as npimport reimport jsonfrom typing import List, Dict, Any, Optional, Tuple, Unionfrom dataclasses import dataclassfrom abc import ABC, abstractmethodfrom enum import Enumimport timefrom collections import defaultdict, Counterimport mathclass QueryType(Enum):"""查询类型枚举"""FACTUAL = "factual" # 事实性查询ANALYTICAL = "analytical" # 分析性查询PROCEDURAL = "procedural" # 程序性查询CREATIVE = "creative" # 创造性查询COMPARATIVE = "comparative" # 比较性查询class RetrievalStrategy(Enum):"""检索策略枚举"""SEMANTIC = "semantic" # 语义检索KEYWORD = "keyword" # 关键词检索HYBRID = "hybrid" # 混合检索HIERARCHICAL = "hierarchical" # 层次检索MULTI_QUERY = "multi_query" # 多查询检索@dataclassclass QueryAnalysis:"""查询分析结果"""original_query: strquery_type: QueryTypekeywords: List[str]entities: List[str]intent: strcomplexity_score: floatdomain: str@dataclassclass RetrievalResult:"""检索结果"""document_id: strcontent: strscore: floatmetadata: Dict[str, Any]retrieval_method: strclass QueryAnalyzer:"""查询分析器"""def __init__(self):# 预定义的查询模式self.query_patterns = {QueryType.FACTUAL: [r'\b(what|who|when|where|which)\b',r'\b(define|definition|meaning)\b',r'\b(is|are|was|were)\b.*\?'],QueryType.ANALYTICAL: [r'\b(why|how|analyze|explain|compare)\b',r'\b(reason|cause|effect|impact)\b',r'\b(relationship|correlation)\b'],QueryType.PROCEDURAL: [r'\b(how to|step|process|procedure)\b',r'\b(install|configure|setup|create)\b',r'\b(tutorial|guide|instruction)\b'],QueryType.CREATIVE: [r'\b(generate|create|design|build)\b',r'\b(idea|suggestion|recommendation)\b',r'\b(brainstorm|innovate)\b'],QueryType.COMPARATIVE: [r'\b(compare|versus|vs|difference)\b',r'\b(better|worse|best|worst)\b',r'\b(advantage|disadvantage|pros|cons)\b']}# 关键词提取模式self.keyword_patterns = [r'\b[A-Z][a-z]+(?:\s[A-Z][a-z]+)*\b', # 专有名词r'\b\w{4,}\b', # 长单词r'\b(?:API|HTTP|JSON|XML|SQL|AI|ML|DL)\b' # 技术术语]def analyze_query(self, query: str) -> QueryAnalysis:"""分析查询意图和特征"""# 查询类型识别query_type = self._identify_query_type(query)# 关键词提取keywords = self._extract_keywords(query)# 实体识别(简化版)entities = self._extract_entities(query)# 意图分析intent = self._analyze_intent(query, query_type)# 复杂度评分complexity_score = self._calculate_complexity(query)# 领域识别domain = self._identify_domain(query, keywords)return QueryAnalysis(original_query=query,query_type=query_type,keywords=keywords,entities=entities,intent=intent,complexity_score=complexity_score,domain=domain)def _identify_query_type(self, query: str) -> QueryType:"""识别查询类型"""query_lower = query.lower()type_scores = {}for query_type, patterns in self.query_patterns.items():score = 0for pattern in patterns:matches = len(re.findall(pattern, query_lower))score += matchestype_scores[query_type] = score# 返回得分最高的类型,默认为事实性查询if not type_scores or max(type_scores.values()) == 0:return QueryType.FACTUALreturn max(type_scores, key=type_scores.get)def _extract_keywords(self, query: str) -> List[str]:"""提取关键词"""keywords = set()for pattern in self.keyword_patterns:matches = re.findall(pattern, query)keywords.update(matches)# 过滤停用词(简化版)stop_words = {'the', 'is', 'at', 'which', 'on', 'and', 'or', 'but', 'in', 'with', 'a', 'an'}keywords = [kw for kw in keywords if kw.lower() not in stop_words]return sorted(list(keywords))def _extract_entities(self, query: str) -> List[str]:"""提取命名实体(简化版)"""# 简化的实体识别,主要识别大写开头的词组entities = re.findall(r'\b[A-Z][a-z]+(?:\s[A-Z][a-z]+)*\b', query)return list(set(entities))def _analyze_intent(self, query: str, query_type: QueryType) -> str:"""分析查询意图"""intent_mapping = {QueryType.FACTUAL: "获取事实信息",QueryType.ANALYTICAL: "深度分析理解",QueryType.PROCEDURAL: 
"获取操作指导",QueryType.CREATIVE: "生成创新内容",QueryType.COMPARATIVE: "比较分析选择"}return intent_mapping.get(query_type, "信息检索")def _calculate_complexity(self, query: str) -> float:"""计算查询复杂度"""factors = {'length': len(query.split()) / 20.0, # 长度因子'questions': query.count('?') * 0.2, # 问题数量'conjunctions': len(re.findall(r'\b(and|or|but|however|although)\b', query.lower())) * 0.3,'technical_terms': len(re.findall(r'\b(?:API|HTTP|JSON|XML|SQL|AI|ML|DL)\b', query)) * 0.4}complexity = sum(factors.values())return min(complexity, 1.0) # 限制在0-1范围内def _identify_domain(self, query: str, keywords: List[str]) -> str:"""识别查询领域"""domain_keywords = {'technology': ['API', 'HTTP', 'JSON', 'XML', 'SQL', 'database', 'server', 'code'],'business': ['market', 'sales', 'revenue', 'customer', 'business', 'strategy'],'science': ['research', 'study', 'experiment', 'data', 'analysis', 'theory'],'education': ['learn', 'teach', 'course', 'tutorial', 'education', 'training']}query_lower = query.lower()domain_scores = {}for domain, domain_kws in domain_keywords.items():score = 0for kw in domain_kws:if kw.lower() in query_lower:score += 1domain_scores[domain] = scoreif domain_scores and max(domain_scores.values()) > 0:return max(domain_scores, key=domain_scores.get)return 'general'class RetrievalStrategySelector:"""检索策略选择器"""def __init__(self):# 策略选择规则self.strategy_rules = {QueryType.FACTUAL: [RetrievalStrategy.SEMANTIC, RetrievalStrategy.KEYWORD],QueryType.ANALYTICAL: [RetrievalStrategy.HYBRID, RetrievalStrategy.HIERARCHICAL],QueryType.PROCEDURAL: [RetrievalStrategy.KEYWORD, RetrievalStrategy.HIERARCHICAL],QueryType.CREATIVE: [RetrievalStrategy.SEMANTIC, RetrievalStrategy.MULTI_QUERY],QueryType.COMPARATIVE: [RetrievalStrategy.HYBRID, RetrievalStrategy.MULTI_QUERY]}# 策略性能历史self.strategy_performance = defaultdict(list)def select_strategy(self, query_analysis: QueryAnalysis) -> List[RetrievalStrategy]:"""选择最优检索策略"""# 基于查询类型的基础策略base_strategies = self.strategy_rules.get(query_analysis.query_type,[RetrievalStrategy.SEMANTIC])# 根据复杂度调整策略if query_analysis.complexity_score > 0.7:# 高复杂度查询使用多策略if RetrievalStrategy.MULTI_QUERY not in base_strategies:base_strategies.append(RetrievalStrategy.MULTI_QUERY)# 根据历史性能调整best_strategies = self._get_best_performing_strategies(query_analysis)if best_strategies:# 结合历史最佳策略combined_strategies = list(set(base_strategies + best_strategies))return combined_strategies[:3] # 限制策略数量return base_strategiesdef _get_best_performing_strategies(self, query_analysis: QueryAnalysis) -> List[RetrievalStrategy]:"""获取历史表现最佳的策略"""domain_key = f"{query_analysis.domain}_{query_analysis.query_type.value}"if domain_key in self.strategy_performance:# 计算各策略的平均性能strategy_scores = defaultdict(list)for record in self.strategy_performance[domain_key]:strategy_scores[record['strategy']].append(record['score'])# 返回平均分最高的策略avg_scores = {strategy: np.mean(scores)for strategy, scores in strategy_scores.items()}sorted_strategies = sorted(avg_scores.items(), key=lambda x: x[1], reverse=True)return [strategy for strategy, _ in sorted_strategies[:2]]return []def record_performance(self, query_analysis: QueryAnalysis,strategy: RetrievalStrategy, score: float):"""记录策略性能"""domain_key = f"{query_analysis.domain}_{query_analysis.query_type.value}"self.strategy_performance[domain_key].append({'strategy': strategy,'score': score,'timestamp': time.time()})class AdvancedRetriever:"""高级检索器"""def __init__(self, vector_database, text_corpus: Dict[str, str]):self.vector_database = vector_databaseself.text_corpus = text_corpus # 
文档ID到文本内容的映射
        self.query_analyzer = QueryAnalyzer()
        self.strategy_selector = RetrievalStrategySelector()

        # 构建关键词索引
        self.keyword_index = self._build_keyword_index()

    def _build_keyword_index(self) -> Dict[str, List[str]]:
        """构建关键词倒排索引"""
        keyword_index = defaultdict(list)

        for doc_id, content in self.text_corpus.items():
            # 简单的关键词提取
            words = re.findall(r'\b\w+\b', content.lower())
            for word in set(words):
                if len(word) > 3:  # 过滤短词
                    keyword_index[word].append(doc_id)

        return dict(keyword_index)

    def retrieve(self, query: str, top_k: int = 10) -> List[RetrievalResult]:
        """执行智能检索"""
        # 1. 查询分析
        query_analysis = self.query_analyzer.analyze_query(query)
        print(f"🔍 查询分析: {query_analysis.query_type.value} | 复杂度: {query_analysis.complexity_score:.2f}")

        # 2. 策略选择
        strategies = self.strategy_selector.select_strategy(query_analysis)
        print(f"📋 选择策略: {[s.value for s in strategies]}")

        # 3. 多策略检索
        all_results = []
        strategy_weights = self._calculate_strategy_weights(strategies, query_analysis)

        for strategy in strategies:
            strategy_results = self._execute_strategy(strategy, query, query_analysis, top_k * 2)

            # 应用策略权重
            weight = strategy_weights.get(strategy, 1.0)
            for result in strategy_results:
                result.score *= weight
                result.retrieval_method = f"{strategy.value}(w={weight:.2f})"

            all_results.extend(strategy_results)

        # 4. 结果融合和重排序
        final_results = self._fuse_and_rerank(all_results, query_analysis, top_k)

        # 5. 记录性能(简化版)
        if final_results:
            avg_score = np.mean([r.score for r in final_results])
            for strategy in strategies:
                self.strategy_selector.record_performance(query_analysis, strategy, avg_score)

        return final_results

    def _calculate_strategy_weights(self, strategies: List[RetrievalStrategy],
                                    query_analysis: QueryAnalysis) -> Dict[RetrievalStrategy, float]:
        """计算策略权重"""
        weights = {}

        for strategy in strategies:
            if strategy == RetrievalStrategy.SEMANTIC:
                # 语义检索在分析性和创造性查询中权重更高
                if query_analysis.query_type in [QueryType.ANALYTICAL, QueryType.CREATIVE]:
                    weights[strategy] = 1.2
                else:
                    weights[strategy] = 1.0
            elif strategy == RetrievalStrategy.KEYWORD:
                # 关键词检索在事实性和程序性查询中权重更高
                if query_analysis.query_type in [QueryType.FACTUAL, QueryType.PROCEDURAL]:
                    weights[strategy] = 1.2
                else:
                    weights[strategy] = 0.8
            elif strategy == RetrievalStrategy.HYBRID:
                # 混合检索权重稳定
                weights[strategy] = 1.1
            else:
                weights[strategy] = 1.0

        return weights

    def _execute_strategy(self, strategy: RetrievalStrategy, query: str,
                          query_analysis: QueryAnalysis, top_k: int) -> List[RetrievalResult]:
        """执行具体的检索策略"""
        if strategy == RetrievalStrategy.SEMANTIC:
            return self._semantic_retrieval(query, top_k)
        elif strategy == RetrievalStrategy.KEYWORD:
            return self._keyword_retrieval(query_analysis.keywords, top_k)
        elif strategy == RetrievalStrategy.HYBRID:
            semantic_results = self._semantic_retrieval(query, top_k // 2)
            keyword_results = self._keyword_retrieval(query_analysis.keywords, top_k // 2)
            return semantic_results + keyword_results
        elif strategy == RetrievalStrategy.HIERARCHICAL:
            return self._hierarchical_retrieval(query, query_analysis, top_k)
        elif strategy == RetrievalStrategy.MULTI_QUERY:
            return self._multi_query_retrieval(query, query_analysis, top_k)
        else:
            return self._semantic_retrieval(query, top_k)

    def _semantic_retrieval(self, query: str, top_k: int) -> List[RetrievalResult]:
        """语义检索"""
        # 模拟向量检索
        # 在实际应用中,这里会调用向量数据库的搜索功能
        # 简化实现:随机选择一些文档并赋予相似度分数
        import random
        random.seed(hash(query) % 1000)

        doc_ids = list(self.text_corpus.keys())
        selected_docs = random.sample(doc_ids, min(top_k, len(doc_ids)))

        results = []
        for doc_id in selected_docs:
            # 模拟相似度计算
            similarity = random.uniform(0.5, 0.95)
            result = RetrievalResult(
                document_id=doc_id,
                content=self.text_corpus[doc_id][:200] + "...",
                score=similarity,
                metadata={"method": "semantic"},
                retrieval_method="semantic"
            )
            results.append(result)

        return sorted(results, key=lambda x: x.score, reverse=True)

    def _keyword_retrieval(self, keywords: List[str], top_k: int) -> List[RetrievalResult]:
        """关键词检索"""
        doc_scores = defaultdict(float)

        for keyword in keywords:
            keyword_lower = keyword.lower()
            if keyword_lower in self.keyword_index:
                matching_docs = self.keyword_index[keyword_lower]
                for doc_id in matching_docs:
                    # 计算TF-IDF风格的分数
                    tf = self.text_corpus[doc_id].lower().count(keyword_lower)
                    idf = math.log(len(self.text_corpus) / len(matching_docs))
                    doc_scores[doc_id] += tf * idf

        # 归一化分数
        if doc_scores:
            max_score = max(doc_scores.values())
            for doc_id in doc_scores:
                doc_scores[doc_id] /= max_score

        # 选择top-k结果
        sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]

        results = []
        for doc_id, score in sorted_docs:
            result = RetrievalResult(
                document_id=doc_id,
                content=self.text_corpus[doc_id][:200] + "...",
                score=score,
                metadata={"method": "keyword", "matched_keywords": keywords},
                retrieval_method="keyword"
            )
            results.append(result)

        return results

    def _hierarchical_retrieval(self, query: str, query_analysis: QueryAnalysis,
                                top_k: int) -> List[RetrievalResult]:
        """层次检索:先粗检索,再精检索"""
        # 第一层:粗检索,获取较多候选
        coarse_results = self._semantic_retrieval(query, top_k * 3)

        # 第二层:基于关键词进行精细化过滤
        refined_results = []
        for result in coarse_results:
            # 计算关键词匹配度
            keyword_match_score = 0
            content_lower = result.content.lower()

            for keyword in query_analysis.keywords:
                if keyword.lower() in content_lower:
                    keyword_match_score += 1

            if query_analysis.keywords:
                keyword_match_score /= len(query_analysis.keywords)

            # 结合语义分数和关键词匹配分数
            combined_score = 0.7 * result.score + 0.3 * keyword_match_score
            result.score = combined_score
            result.retrieval_method = "hierarchical"
            refined_results.append(result)

        return sorted(refined_results, key=lambda x: x.score, reverse=True)[:top_k]

    def _multi_query_retrieval(self, query: str, query_analysis: QueryAnalysis,
                               top_k: int) -> List[RetrievalResult]:
        """多查询检索:生成多个相关查询进行检索"""
        # 生成相关查询
        related_queries = self._generate_related_queries(query, query_analysis)

        all_results = []
        # 对每个查询进行检索
        for related_query in related_queries:
            query_results = self._semantic_retrieval(related_query, top_k // len(related_queries) + 1)

            # 降低相关查询的权重
            for result in query_results:
                result.score *= 0.8
                result.retrieval_method = "multi_query"

            all_results.extend(query_results)

        # 去重并排序
        unique_results = {}
        for result in all_results:
            if result.document_id not in unique_results:
                unique_results[result.document_id] = result
            else:
                # 保留分数更高的结果
                if result.score > unique_results[result.document_id].score:
                    unique_results[result.document_id] = result

        return sorted(unique_results.values(), key=lambda x: x.score, reverse=True)[:top_k]

    def _generate_related_queries(self, query: str, query_analysis: QueryAnalysis) -> List[str]:
        """生成相关查询"""
        related_queries = [query]  # 包含原查询

        # 基于关键词组合生成新查询
        if len(query_analysis.keywords) > 1:
            for i in range(len(query_analysis.keywords)):
                for j in range(i + 1, len(query_analysis.keywords)):
                    related_query = f"{query_analysis.keywords[i]} {query_analysis.keywords[j]}"
                    related_queries.append(related_query)

        # 基于实体生成查询
        for entity in query_analysis.entities:
            related_queries.append(entity)

        return related_queries[:4]  # 限制查询数量

    def _fuse_and_rerank(self, all_results: List[RetrievalResult],
                         query_analysis: QueryAnalysis, top_k: int) -> List[RetrievalResult]:
        """结果融合和重排序"""
        # 去重:合并相同文档的结果
        doc_results = {}
        for result in all_results:
            doc_id = result.document_id
            if doc_id not in doc_results:
                doc_results[doc_id] = result
            else:
                # 融合分数:取最大值并加权平均
                existing_result = doc_results[doc_id]
                fused_score = max(existing_result.score, result.score) * 0.6 + \
                              (existing_result.score + result.score) / 2 * 0.4
                existing_result.score = fused_score
                existing_result.retrieval_method += f"+{result.retrieval_method}"

        unique_results = list(doc_results.values())

        # 重排序:应用查询特定的排序策略
        reranked_results = self._apply_reranking(unique_results, query_analysis)

        return reranked_results[:top_k]

    def _apply_reranking(self, results: List[RetrievalResult],
                         query_analysis: QueryAnalysis) -> List[RetrievalResult]:
        """应用重排序策略"""
        for result in results:
            # 基础分数
            rerank_score = result.score

            # 长度偏好调整
            content_length = len(result.content)
            if query_analysis.query_type == QueryType.PROCEDURAL:
                # 程序性查询偏好较长的内容
                length_bonus = min(content_length / 1000, 0.2)
            else:
                # 其他查询类型偏好适中长度
                length_bonus = max(0, 0.1 - abs(content_length - 500) / 5000)
            rerank_score += length_bonus

            # 新鲜度调整(如果有时间戳信息)
            if 'timestamp' in result.metadata:
                # 简化的新鲜度计算
                freshness_bonus = 0.05  # 假设都是相对新的内容
                rerank_score += freshness_bonus

            result.score = rerank_score

        return sorted(results, key=lambda x: x.score, reverse=True)


# 检索策略优化演示
def demonstrate_retrieval_optimization():
    """演示检索策略优化"""
    print("🎯 检索策略优化演示")
    print("=" * 50)

    # 创建模拟文档库
    text_corpus = {
        "doc_001": "Python是一种高级编程语言,广泛用于Web开发、数据分析和人工智能。它具有简洁的语法和强大的库生态系统。",
        "doc_002": "机器学习是人工智能的一个重要分支,通过算法让计算机从数据中学习模式。常用的算法包括线性回归、决策树和神经网络。",
        "doc_003": "Web开发涉及前端和后端技术。前端使用HTML、CSS和JavaScript,后端可以使用Python、Java或Node.js等技术。",
        "doc_004": "数据库是存储和管理数据的系统。常见的数据库包括MySQL、PostgreSQL和MongoDB。SQL是查询关系数据库的标准语言。",
        "doc_005": "深度学习是机器学习的一个子领域,使用多层神经网络来处理复杂的数据。在图像识别和自然语言处理方面表现出色。",
        "doc_006": "API(应用程序编程接口)允许不同软件系统之间进行通信。RESTful API是目前最流行的API设计风格。",
        "doc_007": "云计算提供了可扩展的计算资源,包括基础设施即服务(IaaS)、平台即服务(PaaS)和软件即服务(SaaS)。",
        "doc_008": "版本控制系统如Git帮助开发者管理代码变更。GitHub是最流行的Git托管平台,支持协作开发。",
        "doc_009": "容器技术如Docker简化了应用部署。Kubernetes是容器编排平台,用于管理大规模容器化应用。",
        "doc_010": "测试驱动开发(TDD)是一种软件开发方法,要求先编写测试,再编写实现代码。这有助于提高代码质量。"
    }

    # 创建模拟向量数据库
    class MockVectorDatabase:
        def search_vectors(self, query_vector, k=10):
            # 模拟向量搜索结果
            return []

    vector_db = MockVectorDatabase()

    # 创建高级检索器
    retriever = AdvancedRetriever(vector_db, text_corpus)

    # 测试不同类型的查询
    test_queries = [
        "什么是Python编程语言?",        # 事实性查询
        "为什么机器学习在AI中很重要?",   # 分析性查询
        "如何使用Git进行版本控制?",      # 程序性查询
        "设计一个Web API的最佳实践",      # 创造性查询
        "比较SQL和NoSQL数据库的优缺点"    # 比较性查询
    ]

    for i, query in enumerate(test_queries, 1):
        print(f"\n🔍 测试查询 {i}: {query}")
        print("-" * 40)

        # 执行检索
        start_time = time.time()
        results = retriever.retrieve(query, top_k=3)
        retrieval_time = time.time() - start_time

        print(f"⏱️ 检索耗时: {retrieval_time:.3f}秒")
        print(f"📋 检索结果 (前3个):")

        for j, result in enumerate(results, 1):
            print(f"  {j}. 文档ID: {result.document_id}")
            print(f"     分数: {result.score:.4f}")
            print(f"     方法: {result.retrieval_method}")
            print(f"     内容: {result.content[:100]}...")
            print()

    print("✅ 检索策略优化演示完成")

# 运行检索策略优化演示
demonstrate_retrieval_optimization()
print("\n🎉 检索策略优化完整演示结束")
📊 检索性能评估
为了持续优化检索效果,我们需要建立完善的评估体系,围绕 Precision@K、Recall@K、MRR、nDCG 等核心指标,对检索质量进行可量化的衡量:
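下面给出一个最小的指标计算草图,仅用于说明上述指标的含义:其中的函数名(如 evaluate_retrieval)和示例数据均为演示自拟,并假设我们已经为每个查询标注了相关文档集合;生产环境中也可以直接使用 ir_measures、BEIR 等成熟的评估工具。

# 检索性能评估示意(简化版,二值相关性)
import math
from typing import List, Set, Dict

def precision_at_k(retrieved: List[str], relevant: Set[str], k: int) -> float:
    """Precision@K:前K个结果中相关文档所占比例"""
    top_k = retrieved[:k]
    if not top_k:
        return 0.0
    hits = sum(1 for doc_id in top_k if doc_id in relevant)
    return hits / len(top_k)

def recall_at_k(retrieved: List[str], relevant: Set[str], k: int) -> float:
    """Recall@K:标注相关文档中被前K个结果覆盖的比例"""
    if not relevant:
        return 0.0
    hits = sum(1 for doc_id in retrieved[:k] if doc_id in relevant)
    return hits / len(relevant)

def mrr(retrieved: List[str], relevant: Set[str]) -> float:
    """MRR:第一个相关文档排名的倒数"""
    for rank, doc_id in enumerate(retrieved, start=1):
        if doc_id in relevant:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(retrieved: List[str], relevant: Set[str], k: int) -> float:
    """nDCG@K:考虑排序位置的归一化折损累计增益"""
    dcg = sum(1.0 / math.log2(rank + 1)
              for rank, doc_id in enumerate(retrieved[:k], start=1)
              if doc_id in relevant)
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    return dcg / idcg if idcg > 0 else 0.0

def evaluate_retrieval(retrieved: List[str], relevant: Set[str], k: int = 5) -> Dict[str, float]:
    """汇总单个查询的各项检索指标"""
    return {
        f"precision@{k}": precision_at_k(retrieved, relevant, k),
        f"recall@{k}": recall_at_k(retrieved, relevant, k),
        "mrr": mrr(retrieved, relevant),
        f"ndcg@{k}": ndcg_at_k(retrieved, relevant, k),
    }

# 使用示例:假设检索器返回的文档ID序列与人工标注的相关文档集合如下
retrieved_ids = ["doc_001", "doc_005", "doc_003", "doc_002", "doc_010"]
relevant_ids = {"doc_001", "doc_002"}
print(evaluate_retrieval(retrieved_ids, relevant_ids, k=5))

在实际评估中,通常会对一批查询分别计算上述指标再取平均值,并结合检索耗时等效率指标,综合判断不同检索策略的优劣。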
🚀 检索策略进阶技术
class AdvancedRetrievalTechniques:
    """高级检索技术集合"""

    def __init__(self):
        self.query_expansion_cache = {}
        self.feedback_history = []

    def query_expansion(self, query: str, expansion_type: str = "synonym") -> List[str]:
        """查询扩展技术"""
        if query in self.query_expansion_cache:
            return self.query_expansion_cache[query]

        expanded_queries = [query]

        if expansion_type == "synonym":
            # 同义词扩展(简化版)
            synonym_map = {
                "编程": ["程序设计", "开发", "coding"],
                "数据库": ["DB", "数据存储", "database"],
                "机器学习": ["ML", "人工智能", "AI算法"],
                "网络": ["网络", "互联网", "web"]
            }

            for original, synonyms in synonym_map.items():
                if original in query:
                    for synonym in synonyms:
                        expanded_queries.append(query.replace(original, synonym))

        elif expansion_type == "context":
            # 上下文扩展
            context_keywords = {
                "Python": ["编程语言", "脚本", "开发"],
                "API": ["接口", "服务", "调用"],
                "数据": ["信息", "统计", "分析"]
            }

            for term, contexts in context_keywords.items():
                if term in query:
                    for context in contexts:
                        expanded_queries.append(f"{query} {context}")

        self.query_expansion_cache[query] = expanded_queries
        return expanded_queries

    def pseudo_relevance_feedback(self, query: str, initial_results: List[RetrievalResult],
                                  feedback_docs: int = 3) -> str:
        """伪相关反馈技术"""
        if len(initial_results) < feedback_docs:
            return query

        # 从top结果中提取关键词
        feedback_keywords = []
        for result in initial_results[:feedback_docs]:
            # 简单的关键词提取
            words = re.findall(r'\b\w{4,}\b', result.content.lower())
            word_freq = Counter(words)
            # 选择高频词作为扩展词
            top_words = [word for word, freq in word_freq.most_common(5) if freq > 1]
            feedback_keywords.extend(top_words)

        # 选择最相关的扩展词
        keyword_freq = Counter(feedback_keywords)
        expansion_words = [word for word, freq in keyword_freq.most_common(3)]

        # 构建扩展查询
        expanded_query = query + " " + " ".join(expansion_words)
        return expanded_query

    def learning_to_rank(self, results: List[RetrievalResult],
                         query_features: Dict[str, float]) -> List[RetrievalResult]:
        """学习排序技术(简化版)"""
        # 特征权重(在实际应用中,这些权重会通过机器学习训练得到)
        feature_weights = {
            "semantic_score": 0.4,
            "keyword_match": 0.3,
            "content_length": 0.1,
            "freshness": 0.1,
            "authority": 0.1
        }

        for result in results:
            # 计算综合排序分数
            ranking_score = 0

            # 语义相似度分数
            ranking_score += result.score * feature_weights["semantic_score"]

            # 关键词匹配分数
            keyword_match_score = self._calculate_keyword_match(result, query_features)
            ranking_score += keyword_match_score * feature_weights["keyword_match"]

            # 内容长度分数
            length_score = min(len(result.content) / 1000, 1.0)
            ranking_score += length_score * feature_weights["content_length"]

            # 新鲜度分数(模拟)
            freshness_score = 0.8  # 假设内容相对新鲜
            ranking_score += freshness_score * feature_weights["freshness"]

            # 权威性分数(模拟)
            authority_score = 0.7  # 假设来源权威性
            ranking_score += authority_score * feature_weights["authority"]

            result.score = ranking_score

        return sorted(results, key=lambda x: x.score, reverse=True)

    def _calculate_keyword_match(self, result: RetrievalResult,
                                 query_features: Dict[str, float]) -> float:
        """计算关键词匹配分数"""
        content_lower = result.content.lower()
        match_score = 0

        # 简化的关键词匹配计算
        for feature, weight in query_features.items():
            if feature.lower() in content_lower:
                match_score += weight

        return min(match_score, 1.0)

    def diversification(self, results: List[RetrievalResult],
                        diversity_threshold: float = 0.7) -> List[RetrievalResult]:
        """结果多样化技术"""
        if len(results) <= 1:
            return results

        diversified_results = [results[0]]  # 保留最相关的结果

        for result in results[1:]:
            # 检查与已选结果的相似度
            is_diverse = True
            for selected_result in diversified_results:
                similarity = self._calculate_content_similarity(result, selected_result)
                if similarity > diversity_threshold:
                    is_diverse = False
                    break

            if is_diverse:
                diversified_results.append(result)

        return diversified_results

    def _calculate_content_similarity(self, result1: RetrievalResult,
                                      result2: RetrievalResult) -> float:
        """计算内容相似度(简化版)"""
        content1_words = set(re.findall(r'\b\w+\b', result1.content.lower()))
        content2_words = set(re.findall(r'\b\w+\b', result2.content.lower()))

        if not content1_words or not content2_words:
            return 0.0

        intersection = len(content1_words & content2_words)
        union = len(content1_words | content2_words)

        return intersection / union if union > 0 else 0.0


# 高级技术演示
def demonstrate_advanced_techniques():
    """演示高级检索技术"""
    print("🚀 高级检索技术演示")
    print("=" * 40)

    techniques = AdvancedRetrievalTechniques()

    # 测试查询扩展
    print("📈 查询扩展技术:")
    original_query = "Python机器学习"

    synonym_expansion = techniques.query_expansion(original_query, "synonym")
    print(f"同义词扩展: {synonym_expansion}")

    context_expansion = techniques.query_expansion(original_query, "context")
    print(f"上下文扩展: {context_expansion}")

    # 测试伪相关反馈
    print(f"\n🔄 伪相关反馈:")
    mock_results = [
        RetrievalResult("doc1", "Python是机器学习的重要工具,提供了scikit-learn等库", 0.9, {}, "mock"),
        RetrievalResult("doc2", "深度学习框架TensorFlow和PyTorch都支持Python", 0.8, {}, "mock"),
        RetrievalResult("doc3", "数据科学家经常使用Python进行数据分析和建模", 0.7, {}, "mock")
    ]

    feedback_query = techniques.pseudo_relevance_feedback(original_query, mock_results)
    print(f"反馈扩展查询: {feedback_query}")

    # 测试结果多样化
    print(f"\n🎨 结果多样化:")
    diverse_results = techniques.diversification(mock_results)
    print(f"多样化后结果数量: {len(diverse_results)}")

    print("✅ 高级技术演示完成")

# 运行高级技术演示
demonstrate_advanced_techniques()
print("\n🎯 检索策略优化章节完成")
通过本节的学习,我们掌握了检索策略优化的核心技术,包括查询分析、策略选择、结果融合和高级优化技术。这些技术的合理应用可以显著提升RAG系统的检索精度和用户体验。
本节我们学习了检索策略优化的理论和实践。下一节我们将探讨生成策略优化,了解如何提升RAG系统的生成质量和一致性。
29.5 生成策略优化
🎯 答案质量保障系统
在我们的知识检索中心中,答案质量保障系统是整个流程的最后一环,也是最关键的环节。就像一个经验丰富的专家顾问,它需要基于检索到的信息,生成准确、相关、连贯的高质量答案。
📝 生成策略核心概念
生成策略优化涉及提示工程、上下文管理、答案质量控制、一致性保证等多个方面,目标是确保RAG系统输出高质量、可信赖的答案。
# 生成策略优化实现
import numpy as np
import re
import json
from typing import List, Dict, Any, Optional, Tuple, Union
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import time
from collections import defaultdict
import hashlib

class GenerationStrategy(Enum):
    """生成策略枚举"""
    EXTRACTIVE = "extractive"        # 抽取式生成
    ABSTRACTIVE = "abstractive"      # 抽象式生成
    HYBRID = "hybrid"                # 混合式生成
    TEMPLATE_BASED = "template"      # 模板式生成
    CHAIN_OF_THOUGHT = "cot"         # 思维链生成

class AnswerQuality(Enum):
    """答案质量等级"""
    EXCELLENT = "excellent"  # 优秀
    GOOD = "good"            # 良好
    FAIR = "fair"            # 一般
    POOR = "poor"            # 较差
    INVALID = "invalid"      # 无效

@dataclass
class GenerationContext:
    """生成上下文"""
    query: str
    retrieved_documents: List[Dict[str, Any]]
    conversation_history: List[Dict[str, str]]
    user_preferences: Dict[str, Any]
    domain: str
    language: str = "zh"

@dataclass
class GeneratedAnswer:
    """生成的答案"""
    content: str
    confidence_score: float
    sources: List[str]
    generation_method: str
    quality_metrics: Dict[str, float]
    metadata: Dict[str, Any]

class PromptTemplate:
    """提示模板类"""

    def __init__(self):
        # 预定义的提示模板
        self.templates = {
            "factual": {
                "system": "你是一个专业的知识助手,基于提供的文档回答用户问题。请确保答案准确、简洁。",
                "user": """基于以下文档内容回答问题:

文档内容:
{documents}

问题:{query}

请提供准确、基于文档的答案。如果文档中没有相关信息,请明确说明。"""
            },
            "analytical": {
                "system": "你是一个分析专家,能够深入分析问题并提供有见地的答案。",
                "user": """基于以下文档内容,深入分析并回答问题:

文档内容:
{documents}

问题:{query}

请提供:
1. 核心观点分析
2. 支撑证据
3. 可能的影响或结论
4. 如有不确定性,请明确指出"""
            },
            "procedural": {
                "system": "