Commit 3cc88241 authored by uuo00_n's avatar uuo00_n

feat: 初始化LLM过滤系统项目

- 添加FastAPI后端基础结构
- 实现用户认证和授权功能
- 添加敏感词检测和过滤功能
- 实现对话管理API
- 添加管理员功能接口
- 配置MongoDB数据库连接
- 添加Ollama集成支持
- 初始化数据库脚本
parents
# 数据库配置
MONGODB_URL=mongodb://localhost:27017
DB_NAME=llm_filter_db
# JWT配置
SECRET_KEY=your_secret_key_here
ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=30
# Ollama配置
OLLAMA_BASE_URL=http://localhost:11434
OLLAMA_MODEL=llama2
\ No newline at end of file
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from app.services.auth import get_current_user
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_active_user(token: str = Depends(oauth2_scheme)):
"""获取当前活跃用户"""
user = await get_current_user(token)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
return user
async def get_current_admin_user(current_user: dict = Depends(get_current_active_user)):
"""获取当前管理员用户"""
if current_user.get("role") != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="权限不足,需要管理员权限"
)
return current_user
\ No newline at end of file
from fastapi import APIRouter, Depends, HTTPException, status
from typing import List, Optional
from datetime import datetime
from app.api.deps import get_current_admin_user
from app.schemas.sensitive_word import SensitiveWordCreate, SensitiveWordResponse, SensitiveRecordResponse
from app.services.sensitive_word import add_sensitive_word, delete_sensitive_word, get_all_sensitive_words, get_sensitive_records
router = APIRouter()
@router.post("/sensitive-words", response_model=dict, status_code=status.HTTP_201_CREATED)
async def create_sensitive_word(
word_data: SensitiveWordCreate,
_: dict = Depends(get_current_admin_user)
):
"""添加敏感词(仅管理员)"""
word_id = await add_sensitive_word(word_data.word, word_data.category)
return {"id": word_id}
@router.delete("/sensitive-words/{word_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_sensitive_word(
word_id: str,
_: dict = Depends(get_current_admin_user)
):
"""删除敏感词(仅管理员)"""
success = await delete_sensitive_word(word_id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="敏感词不存在"
)
return None
@router.get("/sensitive-words", response_model=List[SensitiveWordResponse])
async def list_sensitive_words(
_: dict = Depends(get_current_admin_user)
):
"""获取所有敏感词(仅管理员)"""
return await get_all_sensitive_words()
@router.get("/sensitive-records", response_model=List[SensitiveRecordResponse])
async def list_sensitive_records(
user_id: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
_: dict = Depends(get_current_admin_user)
):
"""获取敏感词记录(仅管理员)"""
return await get_sensitive_records(user_id, start_date, end_date)
\ No newline at end of file
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from app.core.config import settings
from app.db.mongodb import db
from app.schemas.user import UserCreate, UserResponse, Token
from app.services.auth import authenticate_user, create_access_token, get_password_hash
router = APIRouter()
@router.post("/register", response_model=UserResponse)
async def register(user_data: UserCreate):
"""注册新用户"""
# 检查用户名是否已存在
existing_user = await db.db.users.find_one({"username": user_data.username})
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="用户名已存在"
)
# 检查邮箱是否已存在
existing_email = await db.db.users.find_one({"email": user_data.email})
if existing_email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="邮箱已被注册"
)
# 创建新用户
hashed_password = get_password_hash(user_data.password)
user = {
"username": user_data.username,
"email": user_data.email,
"hashed_password": hashed_password,
"role": "user" # 默认为普通用户
}
result = await db.db.users.insert_one(user)
# 获取创建的用户
created_user = await db.db.users.find_one({"_id": result.inserted_id})
return {
"id": str(created_user["_id"]),
"username": created_user["username"],
"email": created_user["email"],
"role": created_user["role"]
}
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
"""用户登录"""
user = await authenticate_user(form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
# 创建访问令牌
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": str(user["_id"])},
expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}
\ No newline at end of file
from fastapi import APIRouter, Depends, HTTPException, status
from app.api.deps import get_current_active_user
from app.schemas.conversation import MessageCreate, ConversationResponse
from app.services.conversation import create_conversation, get_conversation, add_message, get_user_conversations
router = APIRouter()
@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_new_conversation(current_user: dict = Depends(get_current_active_user)):
"""创建新对话"""
conversation_id = await create_conversation(str(current_user["_id"]))
return {"id": conversation_id}
@router.get("/", response_model=list)
async def list_conversations(current_user: dict = Depends(get_current_active_user)):
"""获取用户的所有对话"""
conversations = await get_user_conversations(str(current_user["_id"]))
return conversations
@router.get("/{conversation_id}", response_model=dict)
async def get_single_conversation(
conversation_id: str,
current_user: dict = Depends(get_current_active_user)
):
"""获取单个对话"""
conversation = await get_conversation(conversation_id, str(current_user["_id"]))
if not conversation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="对话不存在"
)
return conversation
@router.post("/{conversation_id}/messages")
async def send_message(
conversation_id: str,
message: MessageCreate,
current_user: dict = Depends(get_current_active_user)
):
"""发送消息并获取回复"""
# 检查对话是否存在
conversation = await get_conversation(conversation_id, str(current_user["_id"]))
if not conversation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="对话不存在"
)
# 添加消息并获取回复
result = await add_message(conversation_id, str(current_user["_id"]), message.content)
return result
\ No newline at end of file
from fastapi import APIRouter
from app.api.v1 import auth, conversation, admin
api_router = APIRouter()
# 注册各模块路由
api_router.include_router(auth.router, prefix="/auth", tags=["认证"])
api_router.include_router(conversation.router, prefix="/conversations", tags=["对话"])
api_router.include_router(admin.router, prefix="/admin", tags=["管理员"])
\ No newline at end of file
import os
from pydantic import BaseSettings
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
class Settings(BaseSettings):
# 应用配置
APP_NAME: str = "LLM过滤系统"
API_V1_STR: str = "/api/v1"
# 数据库配置
MONGODB_URL: str = os.getenv("MONGODB_URL", "mongodb://localhost:27017")
DB_NAME: str = os.getenv("DB_NAME", "llm_filter_db")
# JWT配置
SECRET_KEY: str = os.getenv("SECRET_KEY", "your_secret_key_here")
ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
# Ollama配置
OLLAMA_BASE_URL: str = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
OLLAMA_MODEL: str = os.getenv("OLLAMA_MODEL", "llama2")
settings = Settings()
\ No newline at end of file
from motor.motor_asyncio import AsyncIOMotorClient
from app.core.config import settings
class MongoDB:
client: AsyncIOMotorClient = None
db = None
db = MongoDB()
async def connect_to_mongo():
"""连接到MongoDB数据库"""
db.client = AsyncIOMotorClient(settings.MONGODB_URL)
db.db = db.client[settings.DB_NAME]
print("Connected to MongoDB")
async def close_mongo_connection():
"""关闭MongoDB连接"""
if db.client:
db.client.close()
print("Closed MongoDB connection")
\ No newline at end of file
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.v1.router import api_router
from app.core.config import settings
from app.db.mongodb import connect_to_mongo, close_mongo_connection
from app.utils.sensitive_word_filter import sensitive_word_filter
app = FastAPI(
title=settings.APP_NAME,
openapi_url=f"{settings.API_V1_STR}/openapi.json"
)
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 在生产环境中应该限制为特定域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册API路由
app.include_router(api_router, prefix=settings.API_V1_STR)
@app.on_event("startup")
async def startup_db_client():
"""应用启动时连接数据库并加载敏感词"""
await connect_to_mongo()
await sensitive_word_filter.load_sensitive_words()
@app.on_event("shutdown")
async def shutdown_db_client():
"""应用关闭时断开数据库连接"""
await close_mongo_connection()
@app.get("/")
async def root():
"""根路径,返回应用信息"""
return {
"app_name": settings.APP_NAME,
"version": "1.0.0",
"message": "欢迎使用LLM过滤系统API"
}
\ No newline at end of file
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Field
from bson import ObjectId
from app.models.user import PyObjectId
# 消息模型
class MessageModel(BaseModel):
role: str # "user" 或 "assistant"
content: str
timestamp: datetime = Field(default_factory=datetime.now)
contains_sensitive_words: bool = False
sensitive_words_found: List[str] = []
class Config:
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
# 对话模型
class ConversationModel(BaseModel):
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
user_id: PyObjectId
messages: List[MessageModel] = []
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class Config:
allow_population_by_field_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
\ No newline at end of file
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
from bson import ObjectId
from app.models.user import PyObjectId
# 敏感词模型
class SensitiveWordModel(BaseModel):
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
word: str
category: Optional[str] = None
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class Config:
allow_population_by_field_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
# 敏感词记录模型
class SensitiveRecordModel(BaseModel):
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
user_id: PyObjectId
conversation_id: PyObjectId
message_content: str
sensitive_words_found: List[str]
timestamp: datetime = Field(default_factory=datetime.now)
class Config:
allow_population_by_field_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
\ No newline at end of file
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
from bson import ObjectId
# 自定义ObjectId字段
class PyObjectId(ObjectId):
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v):
if not ObjectId.is_valid(v):
raise ValueError("无效的ObjectId")
return ObjectId(v)
@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(type="string")
# 用户模型
class UserModel(BaseModel):
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
username: str
email: str
hashed_password: str
role: str = "user" # "user" 或 "admin"
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class Config:
allow_population_by_field_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
schema_extra = {
"example": {
"username": "user1",
"email": "user1@example.com",
"hashed_password": "hashed_password_here",
"role": "user",
}
}
\ No newline at end of file
from typing import List, Optional
from pydantic import BaseModel
from datetime import datetime
class MessageCreate(BaseModel):
content: str
class MessageResponse(BaseModel):
role: str
content: str
timestamp: datetime
contains_sensitive_words: bool
sensitive_words_found: List[str]
class ConversationResponse(BaseModel):
id: str
messages: List[MessageResponse]
created_at: datetime
updated_at: datetime
\ No newline at end of file
from typing import List, Optional
from pydantic import BaseModel
from datetime import datetime
class SensitiveWordCreate(BaseModel):
word: str
category: Optional[str] = None
class SensitiveWordResponse(BaseModel):
id: str
word: str
category: Optional[str] = None
created_at: datetime
class SensitiveRecordResponse(BaseModel):
id: str
user_id: str
conversation_id: str
message_content: str
sensitive_words_found: List[str]
timestamp: datetime
\ No newline at end of file
from typing import Optional
from pydantic import BaseModel, EmailStr
class UserCreate(BaseModel):
username: str
email: EmailStr
password: str
class UserLogin(BaseModel):
username: str
password: str
class UserResponse(BaseModel):
id: str
username: str
email: str
role: str
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
user_id: Optional[str] = None
\ No newline at end of file
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from app.core.config import settings
from app.db.mongodb import db
from bson import ObjectId
# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
"""验证密码"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
"""获取密码哈希值"""
return pwd_context.hash(password)
async def authenticate_user(username: str, password: str):
"""验证用户"""
user = await db.db.users.find_one({"username": username})
if not user:
return False
if not verify_password(password, user["hashed_password"]):
return False
return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
"""创建访问令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt
async def get_current_user(token: str):
"""获取当前用户"""
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
user_id = payload.get("sub")
if user_id is None:
return None
user = await db.db.users.find_one({"_id": ObjectId(user_id)})
if user is None:
return None
return user
except JWTError:
return None
\ No newline at end of file
from typing import List, Dict, Any, Optional
from datetime import datetime
from bson import ObjectId
from app.db.mongodb import db
from app.services.ollama import generate_response
from app.utils.sensitive_word_filter import sensitive_word_filter
async def create_conversation(user_id: str) -> str:
"""创建新对话"""
conversation = {
"user_id": ObjectId(user_id),
"messages": [],
"created_at": datetime.now(),
"updated_at": datetime.now()
}
result = await db.db.conversations.insert_one(conversation)
return str(result.inserted_id)
async def get_conversation(conversation_id: str, user_id: str) -> Optional[Dict]:
"""获取对话"""
conversation = await db.db.conversations.find_one({
"_id": ObjectId(conversation_id),
"user_id": ObjectId(user_id)
})
if conversation:
conversation["_id"] = str(conversation["_id"])
conversation["user_id"] = str(conversation["user_id"])
return conversation
async def add_message(conversation_id: str, user_id: str, content: str) -> Dict[str, Any]:
"""
添加用户消息并获取AI回复
Args:
conversation_id: 对话ID
user_id: 用户ID
content: 用户消息内容
Returns:
Dict: 包含处理结果的字典
"""
# 检查敏感词
contains_sensitive, sensitive_words = sensitive_word_filter.check_text(content)
# 创建用户消息
user_message = {
"role": "user",
"content": content,
"timestamp": datetime.now(),
"contains_sensitive_words": contains_sensitive,
"sensitive_words_found": sensitive_words
}
# 更新对话
await db.db.conversations.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {"messages": user_message},
"$set": {"updated_at": datetime.now()}
}
)
# 如果包含敏感词,记录并返回拒绝回复
if contains_sensitive:
# 创建敏感词记录
sensitive_record = {
"user_id": ObjectId(user_id),
"conversation_id": ObjectId(conversation_id),
"message_content": content,
"sensitive_words_found": sensitive_words,
"timestamp": datetime.now()
}
await db.db.sensitive_records.insert_one(sensitive_record)
# 创建系统回复
assistant_message = {
"role": "assistant",
"content": "当前问题暂无法回答。",
"timestamp": datetime.now(),
"contains_sensitive_words": False,
"sensitive_words_found": []
}
# 更新对话
await db.db.conversations.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {"messages": assistant_message},
"$set": {"updated_at": datetime.now()}
}
)
return {
"contains_sensitive_words": True,
"sensitive_words_found": sensitive_words,
"assistant_response": "当前问题暂无法回答。"
}
# 获取对话历史
conversation = await db.db.conversations.find_one({"_id": ObjectId(conversation_id)})
messages = conversation.get("messages", [])
# 准备发送给模型的消息(最多取最近10条)
model_messages = [
{"role": msg["role"], "content": msg["content"]}
for msg in messages[-10:]
]
# 调用模型生成回复
assistant_response = await generate_response(model_messages)
# 创建助手回复消息
assistant_message = {
"role": "assistant",
"content": assistant_response,
"timestamp": datetime.now(),
"contains_sensitive_words": False,
"sensitive_words_found": []
}
# 更新对话
await db.db.conversations.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {"messages": assistant_message},
"$set": {"updated_at": datetime.now()}
}
)
return {
"contains_sensitive_words": False,
"sensitive_words_found": [],
"assistant_response": assistant_response
}
async def get_user_conversations(user_id: str) -> List[Dict]:
"""获取用户的所有对话"""
conversations = []
cursor = db.db.conversations.find({"user_id": ObjectId(user_id)}).sort("updated_at", -1)
async for conversation in cursor:
conversation["_id"] = str(conversation["_id"])
conversation["user_id"] = str(conversation["user_id"])
conversations.append(conversation)
return conversations
\ No newline at end of file
import httpx
from typing import List, Dict, Any
from app.core.config import settings
async def generate_response(messages: List[Dict[str, str]]) -> str:
"""
调用Ollama API生成回复
Args:
messages: 对话历史消息列表,格式为[{"role": "user", "content": "..."}, ...]
Returns:
str: 模型生成的回复
"""
# 转换消息格式为Ollama API所需格式
prompt = ""
for msg in messages:
role_prefix = "User: " if msg["role"] == "user" else "Assistant: "
prompt += f"{role_prefix}{msg['content']}\n"
prompt += "Assistant: "
# 构建请求数据
data = {
"model": settings.OLLAMA_MODEL,
"prompt": prompt,
"stream": False
}
try:
# 发送请求到Ollama API
async with httpx.AsyncClient() as client:
response = await client.post(
f"{settings.OLLAMA_BASE_URL}/api/generate",
json=data,
timeout=60.0
)
response.raise_for_status()
result = response.json()
return result.get("response", "抱歉,我无法生成回复。")
except Exception as e:
print(f"调用Ollama API出错: {str(e)}")
return "抱歉,模型服务暂时不可用。"
\ No newline at end of file
from typing import List, Dict, Any
from datetime import datetime
from bson import ObjectId
from app.db.mongodb import db
from app.utils.sensitive_word_filter import sensitive_word_filter
async def add_sensitive_word(word: str, category: str = None) -> str:
"""添加敏感词"""
sensitive_word = {
"word": word,
"category": category,
"created_at": datetime.now(),
"updated_at": datetime.now()
}
result = await db.db.sensitive_words.insert_one(sensitive_word)
# 更新敏感词过滤器
await sensitive_word_filter.load_sensitive_words()
return str(result.inserted_id)
async def delete_sensitive_word(word_id: str) -> bool:
"""删除敏感词"""
result = await db.db.sensitive_words.delete_one({"_id": ObjectId(word_id)})
# 更新敏感词过滤器
await sensitive_word_filter.load_sensitive_words()
return result.deleted_count > 0
async def get_all_sensitive_words() -> List[Dict[str, Any]]:
"""获取所有敏感词"""
sensitive_words = []
cursor = db.db.sensitive_words.find().sort("word", 1)
async for word in cursor:
word["_id"] = str(word["_id"])
sensitive_words.append(word)
return sensitive_words
async def get_sensitive_records(
user_id: str = None,
start_date: datetime = None,
end_date: datetime = None
) -> List[Dict[str, Any]]:
"""获取敏感词记录"""
query = {}
if user_id:
query["user_id"] = ObjectId(user_id)
if start_date and end_date:
query["timestamp"] = {"$gte": start_date, "$lte": end_date}
elif start_date:
query["timestamp"] = {"$gte": start_date}
elif end_date:
query["timestamp"] = {"$lte": end_date}
records = []
cursor = db.db.sensitive_records.find(query).sort("timestamp", -1)
async for record in cursor:
record["_id"] = str(record["_id"])
record["user_id"] = str(record["user_id"])
record["conversation_id"] = str(record["conversation_id"])
records.append(record)
return records
\ No newline at end of file
from typing import List, Set, Dict, Tuple
from app.db.mongodb import db
class TrieNode:
"""Trie树节点,用于敏感词匹配"""
def __init__(self):
self.children = {}
self.is_end_of_word = False
class SensitiveWordFilter:
def __init__(self):
self.root = TrieNode()
self.sensitive_words = set()
async def load_sensitive_words(self):
"""从数据库加载敏感词"""
self.sensitive_words = set()
self.root = TrieNode()
# 从数据库获取敏感词
cursor = db.db.sensitive_words.find({})
async for document in cursor:
word = document.get("word", "")
if word:
self.sensitive_words.add(word)
self._add_to_trie(word)
def _add_to_trie(self, word: str):
"""将敏感词添加到Trie树中"""
node = self.root
for char in word:
if char not in node.children:
node.children[char] = TrieNode()
node = node.children[char]
node.is_end_of_word = True
def check_text(self, text: str) -> Tuple[bool, List[str]]:
"""
检查文本是否包含敏感词
Args:
text: 要检查的文本
Returns:
Tuple[bool, List[str]]: (是否包含敏感词, 找到的敏感词列表)
"""
if not text:
return False, []
found_words = []
text_lower = text.lower() # 转为小写进行匹配
# 遍历文本的每个字符作为起点
for i in range(len(text_lower)):
node = self.root
for j in range(i, len(text_lower)):
char = text_lower[j]
# 如果字符不在当前节点的子节点中,结束当前匹配
if char not in node.children:
break
node = node.children[char]
# 如果到达某个敏感词的结尾
if node.is_end_of_word:
word = text_lower[i:j+1]
found_words.append(word)
break
return len(found_words) > 0, found_words
# 创建全局敏感词过滤器实例
sensitive_word_filter = SensitiveWordFilter()
\ No newline at end of file
import asyncio
import motor.motor_asyncio
from datetime import datetime
from bson import ObjectId
from passlib.context import CryptContext
# 密码加密工具
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# MongoDB连接配置
MONGODB_URL = "mongodb://localhost:27017"
DB_NAME = "llm_filter_db"
async def init_db():
# 连接到MongoDB
client = motor.motor_asyncio.AsyncIOMotorClient(MONGODB_URL)
db = client[DB_NAME]
# 清空现有集合(如果存在)
collections = await db.list_collection_names()
for collection in collections:
await db[collection].drop()
print("已清空现有集合")
# 创建用户集合并添加假数据
admin_id = ObjectId()
user_id = ObjectId()
users = [
{
"_id": admin_id,
"username": "admin",
"email": "admin@example.com",
"hashed_password": pwd_context.hash("admin123"),
"role": "admin",
"created_at": datetime.now(),
"updated_at": datetime.now()
},
{
"_id": user_id,
"username": "user",
"email": "user@example.com",
"hashed_password": pwd_context.hash("user123"),
"role": "user",
"created_at": datetime.now(),
"updated_at": datetime.now()
}
]
await db.users.insert_many(users)
print(f"已创建用户集合并添加 {len(users)} 条记录")
# 创建敏感词集合并添加假数据
sensitive_words = [
{
"word": "赌博",
"category": "违法活动",
"created_at": datetime.now(),
"updated_at": datetime.now()
},
{
"word": "色情",
"category": "违法活动",
"created_at": datetime.now(),
"updated_at": datetime.now()
},
{
"word": "毒品",
"category": "违法活动",
"created_at": datetime.now(),
"updated_at": datetime.now()
},
{
"word": "诈骗",
"category": "违法活动",
"created_at": datetime.now(),
"updated_at": datetime.now()
},
{
"word": "暴力",
"category": "不良内容",
"created_at": datetime.now(),
"updated_at": datetime.now()
},
{
"word": "自杀",
"category": "不良内容",
"created_at": datetime.now(),
"updated_at": datetime.now()
},
{
"word": "政治敏感",
"category": "政治内容",
"created_at": datetime.now(),
"updated_at": datetime.now()
}
]
await db.sensitive_words.insert_many(sensitive_words)
print(f"已创建敏感词集合并添加 {len(sensitive_words)} 条记录")
# 创建对话集合并添加假数据
conversation_id = ObjectId()
conversations = [
{
"_id": conversation_id,
"user_id": user_id,
"messages": [
{
"role": "user",
"content": "你好,请问你是谁?",
"timestamp": datetime.now(),
"contains_sensitive_words": False,
"sensitive_words_found": []
},
{
"role": "assistant",
"content": "你好!我是一个AI助手,可以回答你的问题和提供帮助。有什么我可以帮你的吗?",
"timestamp": datetime.now(),
"contains_sensitive_words": False,
"sensitive_words_found": []
}
],
"created_at": datetime.now(),
"updated_at": datetime.now()
}
]
await db.conversations.insert_many(conversations)
print(f"已创建对话集合并添加 {len(conversations)} 条记录")
# 创建敏感词记录集合并添加假数据
sensitive_records = [
{
"user_id": user_id,
"conversation_id": conversation_id,
"message_content": "我想了解一下赌博的方法",
"sensitive_words_found": ["赌博"],
"timestamp": datetime.now()
}
]
await db.sensitive_records.insert_many(sensitive_records)
print(f"已创建敏感词记录集合并添加 {len(sensitive_records)} 条记录")
print("\n数据库初始化完成!")
print("\n测试账号:")
print("管理员账号: admin / admin123")
print("用户账号: user / user123")
if __name__ == "__main__":
asyncio.run(init_db())
\ No newline at end of file
fastapi==0.104.1
uvicorn==0.23.2
pydantic==2.4.2
pymongo==4.6.0
motor==3.3.1
python-jose==3.3.0
passlib==1.7.4
python-multipart==0.0.6
bcrypt==4.0.1
httpx==0.25.0
python-dotenv==1.0.0
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment