Commit 6e2b6222 authored by uuo's avatar uuo

refactor(security-service): 重构Zabbix服务异步调用并简化API端点

- 在ZabbixService中添加_run_blocking方法处理阻塞调用兼容性
- 将collect_device_data等方法的同步调用改为异步,避免事件循环阻塞
- 简化API端点,将数据库日志记录逻辑移至SecurityService内部
- 统一时间处理,修复datetime.timezone.utc引用错误
- 重构测试脚本为多功能工具,支持Zabbix数据采集和API调用
- 优化应用启动逻辑,改进健康检查端点实现
parent 4d105058
from datetime import datetime
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from app.schemas.payloads import * from app.schemas.payloads import *
from app.services.analysis import SecurityService from app.services.analysis import SecurityService
from app.services.rss import RSSService from app.services.rss import RSSService
from app.services.zabbix_service import ZabbixService from app.services.zabbix_service import ZabbixService
from app.core.security import get_current_admin from app.core.security import get_current_admin
from app.core.database import db
from datetime import datetime, timezone
router = APIRouter() router = APIRouter()
zabbix_service = ZabbixService() zabbix_service = ZabbixService()
...@@ -14,15 +14,7 @@ rss_service = RSSService() ...@@ -14,15 +14,7 @@ rss_service = RSSService()
@router.post("/analysis", response_model=SecurityAnalysisResponse) @router.post("/analysis", response_model=SecurityAnalysisResponse)
async def analyze_risks(request: SecurityAnalysisRequest, admin: dict = Depends(get_current_admin)): async def analyze_risks(request: SecurityAnalysisRequest, admin: dict = Depends(get_current_admin)):
result = await service.analyze_risks(request.devices) return await service.analyze_risks(request.devices)
# 异步存储结果到 MongoDB
if db.db is not None:
log_entry = result.model_dump()
log_entry["created_at"] = datetime.now(timezone.utc)
await db.db.security_analysis_logs.insert_one(log_entry)
return result
@router.get("/analysis/history", response_model=HistoryQueryResponse) @router.get("/analysis/history", response_model=HistoryQueryResponse)
async def get_analysis_history( async def get_analysis_history(
...@@ -38,15 +30,12 @@ async def get_analysis_history( ...@@ -38,15 +30,12 @@ async def get_analysis_history(
@router.post("/attack-advice", response_model=AttackAdviceResponse) @router.post("/attack-advice", response_model=AttackAdviceResponse)
async def get_attack_advice(request: AttackAdviceRequest, admin: dict = Depends(get_current_admin)): async def get_attack_advice(request: AttackAdviceRequest, admin: dict = Depends(get_current_admin)):
result = await service.get_attack_advice(request.attack_type, request.target_device, request.logs) return await service.get_attack_advice(
attack_type=request.attack_type,
# 异步存储结果到 MongoDB target=request.target_device,
if db.db is not None: logs=request.logs,
log_entry = result.model_dump() severity=request.severity,
log_entry["created_at"] = datetime.now(timezone.utc) )
await db.db.attack_advice_logs.insert_one(log_entry)
return result
@router.get("/attack-advice/history", response_model=HistoryQueryResponse) @router.get("/attack-advice/history", response_model=HistoryQueryResponse)
async def get_attack_advice_history( async def get_attack_advice_history(
......
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime
import logging
from fastapi import FastAPI, HTTPException
from app.api.v1.endpoints import router as security_router from app.api.v1.endpoints import router as security_router
from app.core.config import settings from app.core.config import settings
from app.core.database import db from app.core.database import db
from datetime import datetime from app.services.zabbix_service import ZabbixService
import logging
import logging logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Startup db.connect()
logger = logging.getLogger(__name__)
try:
zabbix_service = ZabbixService()
sync_status = zabbix_service.get_sync_status()
if sync_status.get("collector_initialized"):
logger.info("Zabbix服务初始化成功")
else:
logger.warning("Zabbix服务初始化失败,请检查Zabbix配置")
except Exception as e:
logger.error(f"Zabbix服务连接检查失败: {e}")
yield
db.close()
app = FastAPI(title=settings.PROJECT_NAME, version="2.0.0", lifespan=lifespan)
@app.get("/") @app.get("/")
def health_check(): def health_check():
"""根路径健康检查"""
return {"status": "ok", "service": "security-service", "version": "2.0.0"} return {"status": "ok", "service": "security-service", "version": "2.0.0"}
@app.get("/health") @app.get("/health")
async def detailed_health_check(): async def detailed_health_check():
"""详细健康检查端点"""
try: try:
# 检查数据库连接
db_status = "ok" if db.db else "error" db_status = "ok" if db.db else "error"
# 检查Zabbix服务 zabbix_status = {"status": "not_configured"}
zabbix_status = {"status": "not_configured", "message": "Zabbix服务未配置"}
try: try:
from app.services.zabbix_service import ZabbixService
zabbix_service = ZabbixService() zabbix_service = ZabbixService()
sync_status = zabbix_service.get_sync_status() sync_status = zabbix_service.get_sync_status()
if sync_status["collector_initialized"]: if sync_status.get("collector_initialized"):
zabbix_status = { zabbix_status = {
"status": "ok", "status": "ok",
"last_sync": sync_status["last_sync_time"] "last_sync": sync_status.get("last_sync_time"),
} }
else: else:
zabbix_status = { zabbix_status = {
"status": "error", "status": "error",
"message": "Zabbix collector未初始化" "message": "Zabbix collector未初始化",
} }
except Exception as e: except Exception as e:
zabbix_status = { zabbix_status = {
"status": "error", "status": "error",
"message": str(e) "message": str(e),
} }
from datetime import datetime
return { return {
"status": "healthy", "status": "healthy",
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
...@@ -54,63 +71,34 @@ async def detailed_health_check(): ...@@ -54,63 +71,34 @@ async def detailed_health_check():
"version": "2.0.0", "version": "2.0.0",
"components": { "components": {
"database": {"status": db_status}, "database": {"status": db_status},
"zabbix": zabbix_status "zabbix": zabbix_status,
} },
} }
except Exception as e: except Exception as e:
logger.error(f"Health check failed: {e}") logger.error(f"Health check failed: {e}")
raise HTTPException(status_code=503, detail=f"Service unhealthy: {str(e)}") raise HTTPException(status_code=503, detail=f"Service unhealthy: {str(e)}")
@app.get("/ready") @app.get("/ready")
async def readiness_check(): async def readiness_check():
"""就绪检查端点"""
try: try:
# 检查关键依赖是否就绪
if not db.db: if not db.db:
return { return {
"status": "not_ready", "status": "not_ready",
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"reason": "database_not_connected" "reason": "database_not_connected",
} }
return { return {
"status": "ready", "status": "ready",
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat(),
} }
except Exception as e: except Exception as e:
return { return {
"status": "not_ready", "status": "not_ready",
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"reason": str(e) "reason": str(e),
} }
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
logger = logging.getLogger(__name__)
# 连接数据库
db.connect()
# 初始化Zabbix服务连接检查
try:
from app.services.zabbix_service import ZabbixService
zabbix_service = ZabbixService()
sync_status = zabbix_service.get_sync_status()
if sync_status["collector_initialized"]:
logger.info("✅ Zabbix服务初始化成功")
else:
logger.warning("⚠️ Zabbix服务初始化失败,请检查Zabbix配置")
except Exception as e:
logger.error(f"❌ Zabbix服务连接检查失败: {e}")
yield
# Shutdown
db.close()
app = FastAPI(title=settings.PROJECT_NAME, lifespan=lifespan)
# 注册路由
app.include_router(security_router, prefix=f"{settings.API_V1_STR}/security", tags=["Security"]) app.include_router(security_router, prefix=f"{settings.API_V1_STR}/security", tags=["Security"])
import json import json
import httpx import httpx
from datetime import datetime from datetime import datetime, timezone
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from pydantic import ValidationError from pydantic import ValidationError
from app.schemas.payloads import * from app.schemas.payloads import *
...@@ -67,7 +67,7 @@ class SecurityService: ...@@ -67,7 +67,7 @@ class SecurityService:
try: try:
if db.db is not None: if db.db is not None:
log_entry = result.model_dump() log_entry = result.model_dump()
log_entry["created_at"] = datetime.now(datetime.timezone.utc) log_entry["created_at"] = datetime.now(timezone.utc)
log_entry["device_count"] = len(devices) log_entry["device_count"] = len(devices)
await db.db.security_analysis_logs.insert_one(log_entry) await db.db.security_analysis_logs.insert_one(log_entry)
logger.info("安全分析结果已保存到MongoDB") logger.info("安全分析结果已保存到MongoDB")
...@@ -76,23 +76,41 @@ class SecurityService: ...@@ -76,23 +76,41 @@ class SecurityService:
return result return result
async def get_attack_advice(self, attack_type: str, target: str, logs: str) -> AttackAdviceResponse: async def get_attack_advice(self, attack_type: str, target: str, logs: str, severity: Optional[str] = None) -> AttackAdviceResponse:
""" """
任务:攻击应急建议 任务:攻击应急建议
""" """
# 构造结构化数据
data = { data = {
"attack_type": attack_type, "attack_type": attack_type,
"target_device": target, "target_device": target,
"logs": logs "severity": severity,
"logs": logs,
} }
inputs = { inputs = {
"task_type": "advice", "task_type": "advice",
"context_data": json.dumps(data, ensure_ascii=False) "context_data": json.dumps(data, ensure_ascii=False),
} }
return await self._call_llm(inputs, AttackAdviceResponse) result = await self._call_llm(inputs, AttackAdviceResponse)
try:
if db.db is not None:
log_entry = result.model_dump()
log_entry.update(
{
"created_at": datetime.now(timezone.utc),
"attack_type": attack_type,
"target_device": target,
"severity": severity,
}
)
await db.db.attack_advice_logs.insert_one(log_entry)
logger.info("攻击建议结果已保存到MongoDB")
except Exception as e:
logger.error(f"保存攻击建议结果到MongoDB失败: {e}")
return result
async def generate_report(self) -> SecurityReportResponse: async def generate_report(self) -> SecurityReportResponse:
""" """
...@@ -172,7 +190,7 @@ class SecurityService: ...@@ -172,7 +190,7 @@ class SecurityService:
try: try:
if db.db is not None: if db.db is not None:
log_entry = result.model_dump() log_entry = result.model_dump()
log_entry["created_at"] = datetime.now(datetime.timezone.utc) log_entry["created_at"] = datetime.now(timezone.utc)
log_entry["report_date"] = report_data["date"] log_entry["report_date"] = report_data["date"]
log_entry["real_time_data"] = report_data.get("real_time_data", False) log_entry["real_time_data"] = report_data.get("real_time_data", False)
await db.db.security_report_logs.insert_one(log_entry) await db.db.security_report_logs.insert_one(log_entry)
...@@ -302,6 +320,10 @@ class SecurityService: ...@@ -302,6 +320,10 @@ class SecurityService:
return HistoryQueryResponse(total=total, items=items) return HistoryQueryResponse(total=total, items=items)
async def _call_llm(self, inputs: Dict[str, Any], model_cls): async def _call_llm(self, inputs: Dict[str, Any], model_cls):
if not settings.DIFY_API_URL or not settings.DIFY_API_KEY:
logger.warning("Dify 未配置(缺少 DIFY_API_URL 或 DIFY_API_KEY),跳过 LLM 调用并返回默认结果")
return self._build_default_response(model_cls)
try: try:
url = f"{settings.DIFY_API_URL.rstrip('/')}/chat-messages" url = f"{settings.DIFY_API_URL.rstrip('/')}/chat-messages"
headers = { headers = {
...@@ -327,15 +349,11 @@ class SecurityService: ...@@ -327,15 +349,11 @@ class SecurityService:
payload = { payload = {
"inputs": inputs, "inputs": inputs,
"query": query_prompt, "query": query_prompt,
"response_mode": "streaming", "response_mode": settings.DIFY_RESPONSE_MODE,
"conversation_id": "", "conversation_id": "",
"user": "security-system-api", "user": "security-system-api",
} }
# DEBUG LOG: 打印实际发送的 Payload
logger.info(f"Sending request to Dify. URL: {url}")
logger.info(f"Payload inputs: {json.dumps(inputs, ensure_ascii=False)}")
full_answer = "" full_answer = ""
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
async with client.stream("POST", url, json=payload, headers=headers, timeout=120.0) as resp: async with client.stream("POST", url, json=payload, headers=headers, timeout=120.0) as resp:
...@@ -437,7 +455,7 @@ class SecurityService: ...@@ -437,7 +455,7 @@ class SecurityService:
# 但通常思考在前,正文在后。如果没闭合,说明正文还没出来。 # 但通常思考在前,正文在后。如果没闭合,说明正文还没出来。
# 这里保守处理:如果剩下内容全是思考,那就全删了,返回空串,由上层处理为空的情况。 # 这里保守处理:如果剩下内容全是思考,那就全删了,返回空串,由上层处理为空的情况。
if '<think>' in text: if '<think>' in text:
text = re.sub(r'<think>.*', '', text, flags=re.DOTALL) text = re.sub(r'<think>.*', '', text, flags=re.DOTALL)
# 3. 清洗 Markdown 标记 # 3. 清洗 Markdown 标记
if "```json" in text: if "```json" in text:
......
...@@ -359,6 +359,14 @@ class ZabbixService: ...@@ -359,6 +359,14 @@ class ZabbixService:
self.last_sync_time = None self.last_sync_time = None
self._initialize_collector() self._initialize_collector()
async def _run_blocking(self, func, *args, **kwargs):
try:
to_thread = asyncio.to_thread
except AttributeError:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
return await to_thread(func, *args, **kwargs)
def _initialize_collector(self): def _initialize_collector(self):
"""初始化数据采集器""" """初始化数据采集器"""
try: try:
...@@ -376,22 +384,22 @@ class ZabbixService: ...@@ -376,22 +384,22 @@ class ZabbixService:
"""采集设备数据""" """采集设备数据"""
if not self.collector: if not self.collector:
raise Exception("Zabbix collector未初始化,请检查Zabbix配置") raise Exception("Zabbix collector未初始化,请检查Zabbix配置")
return self.collector.get_security_data_for_analysis() return await self._run_blocking(self.collector.get_security_data_for_analysis)
async def collect_cpu_data(self): async def collect_cpu_data(self):
"""采集CPU和硬件数据""" """采集CPU和硬件数据"""
if not self.collector: if not self.collector:
raise Exception("Zabbix collector未初始化,请检查Zabbix配置") raise Exception("Zabbix collector未初始化,请检查Zabbix配置")
return self.collector.get_cpu_data() return await self._run_blocking(self.collector.get_cpu_data)
async def collect_network_data(self): async def collect_network_data(self):
"""采集网络接口数据""" """采集网络接口数据"""
if not self.collector: if not self.collector:
raise Exception("Zabbix collector未初始化,请检查Zabbix配置") raise Exception("Zabbix collector未初始化,请检查Zabbix配置")
return self.collector.get_network_data() return await self._run_blocking(self.collector.get_network_data)
async def sync_data(self): async def sync_data(self):
"""同步数据""" """同步数据"""
...@@ -426,4 +434,4 @@ class ZabbixService: ...@@ -426,4 +434,4 @@ class ZabbixService:
return { return {
"last_sync_time": self.last_sync_time.isoformat() if self.last_sync_time else None, "last_sync_time": self.last_sync_time.isoformat() if self.last_sync_time else None,
"collector_initialized": self.collector is not None "collector_initialized": self.collector is not None
} }
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment