# agent.py
import os
import requests
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
# from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
# 1. Load and split the document
loader = TextLoader("password-cn.txt", encoding="utf-8")
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
splits = text_splitter.split_documents(docs)
# 2. Create the vector store (embeddings are computed automatically on the first run)
# vectorstore = Chroma.from_documents(
#     documents=splits,
#     embedding=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# )
# Use OllamaEmbeddings instead of the HuggingFace variant above
embedding = OllamaEmbeddings(
    model="qwen2.5:7b"  # put the model name actually installed in your Ollama here, e.g. "qwen2.5:7b"
)
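# Note: a chat model such as qwen2.5 can serve embeddings through Ollama, but a
# dedicated embedding model is usually faster and retrieves better. A possible
# alternative (assumes you have pulled it with `ollama pull nomic-embed-text`):
# embedding = OllamaEmbeddings(model="nomic-embed-text")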
vectorstore = Chroma.from_documents(
    documents=splits,
    embedding=embedding
)
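# Chroma.from_documents is called without persist_directory here, so the index lives in
# memory and the document is re-embedded on every run. To cache embeddings between runs,
# Chroma accepts a persist_directory (the path below is only illustrative):
# vectorstore = Chroma.from_documents(
#     documents=splits, embedding=embedding, persist_directory="./chroma_db"
# )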
retriever = vectorstore.as_retriever()
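# as_retriever() defaults to plain similarity search returning the top 4 chunks; if answers
# seem to miss context, the number of retrieved chunks can be raised, e.g.:
# retriever = vectorstore.as_retriever(search_kwargs={"k": 6})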
# 3. Set up the local LLM
# Prefer Ollama; if it is unavailable, tell the user how to install it
try:
    # Check whether the Ollama service is reachable
    response = requests.get("http://localhost:11434/api/tags", timeout=2)
    if response.status_code == 200:
        # List the models installed in Ollama
        models_data = response.json()
        available_models = [model.get("name", "") for model in models_data.get("models", [])]
        if not available_models:
            raise RuntimeError("No models are installed in Ollama")
        # Preferred models, in priority order
        preferred_models = ["qwen2.5:7b", "qwen2.5:14b", "llama3.2:3b", "llama3:latest", "mistral:7b"]
        # Pick a model: use the first preferred model that is installed,
        # otherwise fall back to the first available model
        selected_model = None
        for preferred in preferred_models:
            # Exact match
            if preferred in available_models:
                selected_model = preferred
                break
            # Partial match (e.g. preferred "llama3:latest" matches an installed "llama3")
            preferred_base = preferred.split(":")[0]
            for available in available_models:
                if available.startswith(preferred_base + ":") or available == preferred_base:
                    selected_model = available
                    break
            if selected_model:
                break
        if not selected_model:
            selected_model = available_models[0]
        from langchain_community.chat_models import ChatOllama
        llm = ChatOllama(
            model=selected_model,
            base_url="http://localhost:11434",  # default Ollama address
        )
        print(f"✓ Using local Ollama model: {selected_model}")
        if selected_model not in preferred_models[:3]:  # not one of the top 3 recommendations
            print("💡 Tip: install qwen2.5:7b for better Chinese support")
            print("   Run: ollama pull qwen2.5:7b")
    else:
        raise ConnectionError("The Ollama service did not respond")
except requests.exceptions.ConnectionError:
    # The Ollama service is unavailable
    error_msg = """
╔═══════════════════════════════════════════════════════════════╗
║  ❌ The Ollama service is not running or not installed!        ║
╚═══════════════════════════════════════════════════════════════╝
Please follow these steps:
1️⃣  Install Ollama:
   • Download it from https://ollama.ai
   • Or use Homebrew: brew install ollama
2️⃣  Start the Ollama service (it usually starts automatically after installation):
   ollama serve
3️⃣  Pull a model (pick one):
   ollama pull qwen2.5:7b    # recommended: good Chinese support (about 4.7 GB)
   # or
   ollama pull llama3.2:3b   # smaller and faster (about 2 GB)
4️⃣  Run the app again
═══════════════════════════════════════════════════════════════
"""
    print(error_msg)
    raise RuntimeError("The Ollama service is unavailable. Please install and start Ollama first.")
except requests.exceptions.Timeout:
    error_msg = "❌ Timed out connecting to the Ollama service; make sure Ollama is running"
    print(error_msg)
    raise RuntimeError(error_msg)
except Exception as e:
    # Any other error (e.g. model selection or initialization failed)
    error_msg = f"""
╔═══════════════════════════════════════════════════════════════╗
║           ❌ Error initializing the Ollama model               ║
╚═══════════════════════════════════════════════════════════════╝
Error: {str(e)}
Please make sure that:
1. The Ollama service is running
2. At least one model is installed (run: ollama pull <model name>)
═══════════════════════════════════════════════════════════════
"""
    print(error_msg)
    raise
# 4. Build the RAG chain (Retrieval-Augmented Generation)
template = """
You are a personal advisor. Answer the question based on the context below.
If you do not know the answer, say "This cannot be answered from the available documents";
do not make anything up.
Context:
{context}
Question:
{question}
"""
prompt = ChatPromptTemplate.from_template(template)
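# How the chain below works: invoking it with a question string sends that string to the
# retriever (whose matching chunks become "context") and, via RunnablePassthrough, also
# forwards it unchanged as "question"; the filled prompt then goes to the local model,
# and StrOutputParser turns the chat response into a plain string.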
rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
# 5. Expose a callable entry point
# Gradio's ChatInterface passes (message, history) as two positional arguments by default
def ask_question(message: str, history=None) -> str:
    # Only the current user message matters here; the history is ignored
    return rag_chain.invoke(message)
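
# Minimal usage sketch (assumption: gradio is installed via `pip install gradio`); the
# Gradio ChatInterface mentioned above would call ask_question like this:
if __name__ == "__main__":
    import gradio as gr

    gr.ChatInterface(fn=ask_question, title="Local RAG agent").launch()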