diff --git a/.env.example b/.env.example index 1213de1..4c91c22 100644 --- a/.env.example +++ b/.env.example @@ -10,4 +10,7 @@ ENABLE_CONVERSATION_MEMORY=true # Server Configuration SERVER_HOST="127.0.0.1" -SERVER_PORT=5000 \ No newline at end of file +SERVER_PORT=5000 + +# OpenAI compatable API Keys +VALID_API_KEYS="sk-abc123,sk-def456" \ No newline at end of file diff --git a/main.py b/main.py index 2e50e6e..dfcae72 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ import json import logging import asyncio -from flask import Flask, request, Response, stream_with_context +from flask import Flask, request, Response, stream_with_context, jsonify import httpx import time from dotenv import load_dotenv @@ -21,6 +21,9 @@ logging.getLogger("httpx").setLevel(logging.WARNING) # 加载环境变量 load_dotenv() +# 从环境变量读取有效的API密钥(逗号分隔) +VALID_API_KEYS = [key.strip() for key in os.getenv("VALID_API_KEYS", "").split(",") if key] + # 获取会话ID记忆功能开关配置 ENABLE_CONVERSATION_MEMORY = os.getenv('ENABLE_CONVERSATION_MEMORY', 'true').lower() == 'true' @@ -312,6 +315,41 @@ def decode_conversation_id(content): @app.route('/v1/chat/completions', methods=['POST']) def chat_completions(): try: + # 新增:验证API密钥 + auth_header = request.headers.get('Authorization') + if not auth_header: + return jsonify({ + "error": { + "message": "Missing Authorization header", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key" + } + }), 401 + + parts = auth_header.split() + if len(parts) != 2 or parts[0].lower() != 'bearer': + return jsonify({ + "error": { + "message": "Invalid Authorization header format. Expected: Bearer ", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key" + } + }), 401 + + provided_api_key = parts[1] + if provided_api_key not in VALID_API_KEYS: + return jsonify({ + "error": { + "message": "Invalid API key", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key" + } + }), 401 + + # 继续处理原始逻辑 openai_request = request.get_json() logger.info(f"Received request: {json.dumps(openai_request, ensure_ascii=False)}") @@ -593,7 +631,11 @@ def list_models(): logger.info(f"Available models: {json.dumps(response, ensure_ascii=False)}") return response +# 在main.py的最后初始化时添加环境变量检查: if __name__ == '__main__': + if not VALID_API_KEYS: + print("Warning: No API keys configured. Set the VALID_API_KEYS environment variable with comma-separated keys.") + # 启动时初始化模型信息 asyncio.run(model_manager.refresh_model_info())