Unverified Commit 28ce2b74 authored by uuo00_n's avatar uuo00_n Committed by GitHub

Merge pull request #13 from uuo00n/feature/conversation-unify-and-title-delete-cors

feat: 对话 ID/标题统一、敏感词结构化、删除接口与 CORS 支持
parents aea76f67 789fdcbe
from fastapi import APIRouter, Depends, HTTPException, status
from typing import List
from app.api.deps import get_current_active_user, require_edition_for_mode
from app.schemas.conversation import MessageCreate, ConversationDocOut, CreatedId, MessageSendResult
from app.services.conversation import create_conversation, get_conversation, add_message, get_user_conversations
from app.schemas.conversation import (
MessageCreate,
ConversationDocOut,
ConversationResponse,
CreatedId,
MessageSendResult,
DeleteResult,
)
from app.services.conversation import (
create_conversation,
get_conversation,
add_message,
get_user_conversations,
delete_conversation,
)
# 在路由层挂载版别运行模式依赖,限制仅允许当前模式的用户访问
router = APIRouter(dependencies=[Depends(require_edition_for_mode())])
@router.post("/", status_code=status.HTTP_201_CREATED, response_model=CreatedId)
async def create_new_conversation(current_user: dict = Depends(get_current_active_user)):
"""创建新对话"""
"""创建新对话(返回新建 ID)
用途:为当前用户新建会话,并返回新会话的 ID。
依赖:鉴权用户、版别运行模式。
"""
conversation_id = await create_conversation(str(current_user["_id"]))
return {"id": conversation_id}
@router.get("/", response_model=List[ConversationDocOut])
@router.get("/", response_model=List[ConversationResponse])
async def list_conversations(current_user: dict = Depends(get_current_active_user)):
"""获取用户的所有对话"""
"""获取用户的所有对话(列表优化)
用途:返回当前用户的对话列表,统一字段并降低负载(仅最近一条消息)。
"""
conversations = await get_user_conversations(str(current_user["_id"]))
return conversations
......@@ -24,7 +42,9 @@ async def get_single_conversation(
conversation_id: str,
current_user: dict = Depends(get_current_active_user)
):
"""获取单个对话"""
"""获取单个对话详情
用途:返回指定对话的完整消息列表,统一 ID 与标题字段。
"""
conversation = await get_conversation(conversation_id, str(current_user["_id"]))
if not conversation:
raise HTTPException(
......@@ -33,13 +53,29 @@ async def get_single_conversation(
)
return conversation
@router.delete("/{conversation_id}", response_model=DeleteResult)
async def remove_conversation(
conversation_id: str,
current_user: dict = Depends(get_current_active_user)
):
"""删除对话
用途:仅允许删除当前用户归属的对话,并清理关联敏感记录。
返回:删除结果(deleted: bool)
"""
ok = await delete_conversation(conversation_id, str(current_user["_id"]))
if not ok:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="对话不存在或无权限")
return {"deleted": True, "message": "删除成功"}
@router.post("/{conversation_id}/messages", response_model=MessageSendResult)
async def send_message(
conversation_id: str,
message: MessageCreate,
current_user: dict = Depends(get_current_active_user)
):
"""发送消息并获取回复"""
"""发送消息并获取回复
用途:在指定对话中发送用户消息;若命中敏感词返回拒绝回复并记录;否则返回模型生成的助手回复。
"""
# 检查对话是否存在
conversation = await get_conversation(conversation_id, str(current_user["_id"]))
if not conversation:
......
......@@ -2,6 +2,19 @@ from typing import List, Optional
from pydantic import BaseModel
from datetime import datetime
class SensitiveWordInfoResponse(BaseModel):
"""敏感词信息结构化模型
字段说明:
- word: 敏感词文本
- category: 一级分类
- subcategory: 二级分类
- severity: 严重程度(数值越大越严重)
"""
word: str
category: Optional[str] = None
subcategory: Optional[str] = None
severity: Optional[int] = None
class MessageCreate(BaseModel):
content: str
......@@ -10,17 +23,20 @@ class MessageResponse(BaseModel):
content: str
timestamp: datetime
contains_sensitive_words: bool
sensitive_words_found: List[str]
sensitive_words_found: List[SensitiveWordInfoResponse]
class ConversationResponse(BaseModel):
id: str
title: Optional[str] = None
messages: List[MessageResponse]
created_at: datetime
updated_at: datetime
class ConversationDocOut(BaseModel):
_id: str
id: str
_id: Optional[str] = None
user_id: Optional[str] = None
title: Optional[str] = None
messages: List[MessageResponse]
created_at: datetime
updated_at: datetime
......@@ -28,7 +44,11 @@ class ConversationDocOut(BaseModel):
class CreatedId(BaseModel):
id: str
class DeleteResult(BaseModel):
deleted: bool
message: Optional[str] = None
class MessageSendResult(BaseModel):
contains_sensitive_words: bool
sensitive_words_found: List[str]
sensitive_words_found: List[SensitiveWordInfoResponse]
assistant_response: str
\ No newline at end of file
......@@ -6,9 +6,16 @@ from app.services.ollama import generate_response
from app.utils.sensitive_word_filter import sensitive_word_filter
async def create_conversation(user_id: str) -> str:
"""创建新对话"""
"""创建新对话
用途:为指定用户创建一个新的对话,初始化空消息与默认标题。
入参:
- user_id: 当前用户的字符串 ID
返回:
- 新建对话的字符串 ID
"""
conversation = {
"user_id": ObjectId(user_id),
"title": f"新会话 {datetime.now().strftime('%m-%d %H:%M')}",
"messages": [],
"created_at": datetime.now(),
"updated_at": datetime.now()
......@@ -18,47 +25,120 @@ async def create_conversation(user_id: str) -> str:
return str(result.inserted_id)
async def get_conversation(conversation_id: str, user_id: str) -> Optional[Dict]:
"""获取对话"""
"""获取单个对话
用途:根据对话 ID 和用户 ID 获取该用户的单个对话,并规范返回字段。
入参:
- conversation_id: 对话ID
- user_id: 用户ID
返回:
- 对话字典(含 id/_id/title/messages 等)或 None
"""
conversation = await db.db.conversations.find_one({
"_id": ObjectId(conversation_id),
"user_id": ObjectId(user_id)
})
if conversation:
conversation["_id"] = str(conversation["_id"])
conversation["user_id"] = str(conversation["user_id"])
if not conversation:
return None
# 统一 ID 字段与用户 ID 字符串化
conversation["id"] = str(conversation["_id"])
conversation["_id"] = str(conversation["_id"]) # 兼容旧前端
conversation["user_id"] = str(conversation["user_id"])
# 标题兜底
if not conversation.get("title"):
conversation["title"] = f"新会话 {conversation.get('created_at', datetime.now()).strftime('%m-%d %H:%M')}"
# 将消息中的敏感词字符串列表转换为结构化对象
messages: List[Dict[str, Any]] = conversation.get("messages", [])
# 汇总所有可能的敏感词
all_words: List[str] = []
for m in messages:
sw = m.get("sensitive_words_found", [])
if sw and isinstance(sw, list) and (len(sw) == 0 or isinstance(sw[0], str)):
all_words.extend([w for w in sw if isinstance(w, str)])
unique_words = list(set(all_words))
word_detail_map: Dict[str, Dict[str, Any]] = {}
if unique_words:
cursor = db.db.sensitive_words.find({"word": {"$in": unique_words}})
async for doc in cursor:
word_detail_map[doc.get("word")] = {
"word": doc.get("word"),
"category": doc.get("category"),
"subcategory": doc.get("subcategory"),
"severity": doc.get("severity", 1),
}
for m in messages:
sw = m.get("sensitive_words_found", [])
if sw and isinstance(sw, list) and (len(sw) == 0 or isinstance(sw[0], str)):
details: List[Dict[str, Any]] = []
for w in sw:
if not isinstance(w, str):
continue
details.append(word_detail_map.get(w, {"word": w, "category": None, "subcategory": None, "severity": 1}))
m["sensitive_words_found"] = details
conversation["messages"] = messages
return conversation
async def add_message(conversation_id: str, user_id: str, content: str) -> Dict[str, Any]:
"""
添加用户消息并获取AI回复
Args:
用途:写入用户消息,进行敏感词检测;如有敏感词,记录审计并返回拒绝回复;否则调用模型生成回复。
入参:
conversation_id: 对话ID
user_id: 用户ID
content: 用户消息内容
Returns:
Dict: 包含处理结果的字典
返回:
Dict: 包含处理结果的字典(含 contains_sensitive_words / sensitive_words_found / assistant_response)
"""
# 预取对话用于判断是否首次消息与标题更新
conversation = await db.db.conversations.find_one({
"_id": ObjectId(conversation_id),
"user_id": ObjectId(user_id)
})
if not conversation:
raise ValueError("对话不存在或无权限")
is_first_message = len(conversation.get("messages", [])) == 0
current_title = conversation.get("title", "")
# 检查敏感词
check_result = sensitive_word_filter.check_text(content)
contains_sensitive = check_result["contains_sensitive_words"]
sensitive_words = check_result["sensitive_words_found"]
highest_severity = check_result["highest_severity"]
# 创建用户消息
# 详细敏感词信息(用于响应与审计)
detailed_words: List[Dict[str, Any]] = []
if contains_sensitive and sensitive_words:
cursor = db.db.sensitive_words.find({"word": {"$in": sensitive_words}})
async for doc in cursor:
detailed_words.append({
"word": doc.get("word"),
"category": doc.get("category"),
"subcategory": doc.get("subcategory"),
"severity": doc.get("severity", 1),
})
# 对未命中库的词提供兜底
known = set(dw["word"] for dw in detailed_words)
for w in sensitive_words:
if w not in known:
detailed_words.append({"word": w, "category": None, "subcategory": None, "severity": 1})
# 创建用户消息(消息中也保存结构化敏感词列表,便于后续展示)
user_message = {
"role": "user",
"content": content,
"timestamp": datetime.now(),
"contains_sensitive_words": contains_sensitive,
"sensitive_words_found": sensitive_words,
"sensitive_words_found": detailed_words if contains_sensitive else [],
"highest_severity": highest_severity
}
# 更新对话
# 更新对话(写入用户消息与更新时间)
await db.db.conversations.update_one(
{"_id": ObjectId(conversation_id)},
{
......@@ -66,20 +146,17 @@ async def add_message(conversation_id: str, user_id: str, content: str) -> Dict[
"$set": {"updated_at": datetime.now()}
}
)
# 首次用户消息或原标题以“新会话”开头时,用内容前20字更新标题
if is_first_message or (current_title.startswith("新会话") if current_title else True):
new_title = content.strip()[:20] or f"新会话 {datetime.now().strftime('%m-%d %H:%M')}"
await db.db.conversations.update_one(
{"_id": ObjectId(conversation_id)},
{"$set": {"title": new_title, "updated_at": datetime.now()}}
)
# 如果包含敏感词,记录并返回拒绝回复
if contains_sensitive:
# 补充敏感词详细信息
detailed_words = []
if sensitive_words:
cursor = db.db.sensitive_words.find({"word": {"$in": sensitive_words}})
async for doc in cursor:
detailed_words.append({
"word": doc.get("word"),
"category": doc.get("category"),
"subcategory": doc.get("subcategory"),
"severity": doc.get("severity", 1),
})
highest = highest_severity
if detailed_words:
highest = max([dw.get("severity", 1) for dw in detailed_words])
......@@ -116,23 +193,21 @@ async def add_message(conversation_id: str, user_id: str, content: str) -> Dict[
return {
"contains_sensitive_words": True,
"sensitive_words_found": sensitive_words,
"sensitive_words_found": detailed_words,
"assistant_response": "当前问题暂无法回答。"
}
# 获取对话历史
# 获取对话历史(最多取最近10条)
conversation = await db.db.conversations.find_one({"_id": ObjectId(conversation_id)})
messages = conversation.get("messages", [])
# 准备发送给模型的消息(最多取最近10条)
model_messages = [
{"role": msg["role"], "content": msg["content"]}
for msg in messages[-10:]
]
# 调用模型生成回复
assistant_response = await generate_response(model_messages)
# 创建助手回复消息
assistant_message = {
"role": "assistant",
......@@ -141,7 +216,7 @@ async def add_message(conversation_id: str, user_id: str, content: str) -> Dict[
"contains_sensitive_words": False,
"sensitive_words_found": []
}
# 更新对话
await db.db.conversations.update_one(
{"_id": ObjectId(conversation_id)},
......@@ -150,7 +225,7 @@ async def add_message(conversation_id: str, user_id: str, content: str) -> Dict[
"$set": {"updated_at": datetime.now()}
}
)
return {
"contains_sensitive_words": False,
"sensitive_words_found": [],
......@@ -158,13 +233,68 @@ async def add_message(conversation_id: str, user_id: str, content: str) -> Dict[
}
async def get_user_conversations(user_id: str) -> List[Dict]:
"""获取用户的所有对话"""
conversations = []
"""获取用户的所有对话(列表优化)
用途:返回用户的对话列表,统一 ID 字段与标题;为降低负载,仅返回最近一条消息。
入参:
- user_id: 用户ID
返回:
- 对话字典列表(每项仅含最近一条 messages)
"""
conversations: List[Dict[str, Any]] = []
cursor = db.db.conversations.find({"user_id": ObjectId(user_id)}).sort("updated_at", -1)
async for conversation in cursor:
conversation["_id"] = str(conversation["_id"])
conversation["user_id"] = str(conversation["user_id"])
conversations.append(conversation)
async for c in cursor:
c_id = str(c["_id"])
c_user_id = str(c["user_id"])
title = c.get("title") or f"新会话 {c.get('created_at', datetime.now()).strftime('%m-%d %H:%M')}"
# 仅返回最近一条消息
last_msg = c.get("messages", [])[-1:] # 列表切片保证仍是 List
# 结构化敏感词
if last_msg:
sw = last_msg[0].get("sensitive_words_found", [])
if sw and isinstance(sw, list) and (len(sw) == 0 or isinstance(sw[0], str)):
details: List[Dict[str, Any]] = []
if sw:
cursor_sw = db.db.sensitive_words.find({"word": {"$in": sw}})
known: Dict[str, Dict[str, Any]] = {}
async for doc in cursor_sw:
known[doc.get("word")] = {
"word": doc.get("word"),
"category": doc.get("category"),
"subcategory": doc.get("subcategory"),
"severity": doc.get("severity", 1),
}
for w in sw:
details.append(known.get(w, {"word": w, "category": None, "subcategory": None, "severity": 1}))
last_msg[0]["sensitive_words_found"] = details
conversations.append({
"id": c_id,
"_id": c_id,
"user_id": c_user_id,
"title": title,
"messages": last_msg,
"created_at": c.get("created_at"),
"updated_at": c.get("updated_at"),
})
return conversations
\ No newline at end of file
return conversations
async def delete_conversation(conversation_id: str, user_id: str) -> bool:
"""删除用户对话并清理关联敏感记录
用途:仅删除当前用户归属的对话;若删除成功,清理敏感审计记录。
入参:
- conversation_id: 对话ID
- user_id: 用户ID
返回:
- 是否删除成功(True/False)
"""
filter_cond = {"_id": ObjectId(conversation_id), "user_id": ObjectId(user_id)}
del_res = await db.db.conversations.delete_one(filter_cond)
if del_res.deleted_count != 1:
return False
# 清理敏感记录
await db.db.sensitive_records.delete_many({
"conversation_id": ObjectId(conversation_id),
"user_id": ObjectId(user_id)
})
return True
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment