#!/usr/bin/env python3 """ PostgreSQL 数据库自动设置脚本 功能: 1. 自动连接到PostgreSQL服务器 2. 创建数据库和用户 3. 设置权限 4. 初始化表结构 使用方法: python backend/scripts/setup_postgres.py 前置条件: - PostgreSQL服务已安装并运行 - 知道PostgreSQL的超级用户密码(通常是postgres用户) """ import sys import asyncio from pathlib import Path from getpass import getpass import logging # 添加项目根目录到Python路径 sys.path.insert(0, str(Path(__file__).parent.parent)) try: import psycopg2 from psycopg2 import sql from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT except ImportError: print("❌ 缺少psycopg2依赖,正在安装...") import subprocess subprocess.check_call([sys.executable, "-m", "pip", "install", "psycopg2-binary"]) import psycopg2 from psycopg2 import sql from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT # 注意: 表结构应由 Alembic 管理 from pathlib import Path # 设置日志 logging.basicConfig( level=logging.INFO, format='%(message)s' ) logger = logging.getLogger(__name__) class PostgreSQLSetup: """PostgreSQL数据库自动设置""" def __init__( self, host: str = "localhost", port: int = 5432, admin_user: str = "postgres", admin_password: str = None, db_name: str = "mumuai_novel", db_user: str = "mumuai", db_password: str = "123456" ): """ 初始化设置参数 Args: host: PostgreSQL主机地址 port: PostgreSQL端口 admin_user: 管理员用户名 admin_password: 管理员密码 db_name: 要创建的数据库名 db_user: 要创建的用户名 db_password: 用户密码 """ self.host = host self.port = port self.admin_user = admin_user self.admin_password = admin_password self.db_name = db_name self.db_user = db_user self.db_password = db_password self.conn = None def connect_as_admin(self) -> bool: """连接到PostgreSQL(使用管理员权限)""" try: logger.info(f"🔌 连接到 PostgreSQL ({self.host}:{self.port})...") self.conn = psycopg2.connect( host=self.host, port=self.port, user=self.admin_user, password=self.admin_password, database="postgres" # 连接到默认数据库 ) self.conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) logger.info(f"✅ 已连接到 PostgreSQL") return True except psycopg2.OperationalError as e: logger.error(f"❌ 连接失败: {e}") logger.error("\n可能的原因:") logger.error("1. PostgreSQL服务未启动") logger.error("2. 管理员密码错误") logger.error("3. 主机地址或端口错误") logger.error("4. pg_hba.conf配置不允许连接") return False def database_exists(self) -> bool: """检查数据库是否存在""" cursor = self.conn.cursor() cursor.execute( "SELECT 1 FROM pg_database WHERE datname = %s", (self.db_name,) ) exists = cursor.fetchone() is not None cursor.close() return exists def user_exists(self) -> bool: """检查用户是否存在""" cursor = self.conn.cursor() cursor.execute( "SELECT 1 FROM pg_user WHERE usename = %s", (self.db_user,) ) exists = cursor.fetchone() is not None cursor.close() return exists def create_user(self) -> bool: """创建数据库用户""" try: if self.user_exists(): logger.info(f"ℹ️ 用户 '{self.db_user}' 已存在") # 询问是否重置密码 response = input(f"是否重置用户 '{self.db_user}' 的密码? (yes/no): ") if response.lower() in ['yes', 'y']: cursor = self.conn.cursor() cursor.execute( sql.SQL("ALTER USER {} WITH PASSWORD %s").format( sql.Identifier(self.db_user) ), (self.db_password,) ) cursor.close() logger.info(f"✅ 用户密码已更新") return True logger.info(f"👤 创建用户 '{self.db_user}'...") cursor = self.conn.cursor() cursor.execute( sql.SQL("CREATE USER {} WITH PASSWORD %s").format( sql.Identifier(self.db_user) ), (self.db_password,) ) cursor.close() logger.info(f"✅ 用户创建成功") return True except Exception as e: logger.error(f"❌ 创建用户失败: {e}") return False def create_database(self) -> bool: """创建数据库""" try: if self.database_exists(): logger.info(f"ℹ️ 数据库 '{self.db_name}' 已存在") # 询问是否删除重建 response = input(f"是否删除并重建数据库 '{self.db_name}'? (yes/no): ") if response.lower() in ['yes', 'y']: logger.warning(f"⚠️ 删除数据库 '{self.db_name}'...") cursor = self.conn.cursor() # 断开所有连接 cursor.execute( sql.SQL(""" SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pg_stat_activity.datname = %s AND pid <> pg_backend_pid() """), (self.db_name,) ) cursor.execute( sql.SQL("DROP DATABASE {}").format( sql.Identifier(self.db_name) ) ) cursor.close() logger.info(f"✅ 数据库已删除") else: return True logger.info(f"🗄️ 创建数据库 '{self.db_name}'...") cursor = self.conn.cursor() cursor.execute( sql.SQL("CREATE DATABASE {} OWNER {}").format( sql.Identifier(self.db_name), sql.Identifier(self.db_user) ) ) cursor.close() logger.info(f"✅ 数据库创建成功") return True except Exception as e: logger.error(f"❌ 创建数据库失败: {e}") return False def grant_privileges(self) -> bool: """授予用户权限""" try: logger.info(f"🔐 授予用户权限...") cursor = self.conn.cursor() # 授予数据库所有权限 cursor.execute( sql.SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {}").format( sql.Identifier(self.db_name), sql.Identifier(self.db_user) ) ) cursor.close() logger.info(f"✅ 权限授予成功") return True except Exception as e: logger.error(f"❌ 授予权限失败: {e}") return False def update_env_file(self) -> bool: """更新.env文件""" try: env_file = Path(__file__).parent.parent / ".env" database_url = ( f"postgresql+asyncpg://{self.db_user}:{self.db_password}" f"@{self.host}:{self.port}/{self.db_name}" ) if env_file.exists(): # 读取现有内容 with open(env_file, 'r', encoding='utf-8') as f: lines = f.readlines() # 更新DATABASE_URL updated = False for i, line in enumerate(lines): if line.startswith('DATABASE_URL='): lines[i] = f"DATABASE_URL={database_url}\n" updated = True break if not updated: lines.append(f"\nDATABASE_URL={database_url}\n") # 写回文件 with open(env_file, 'w', encoding='utf-8') as f: f.writelines(lines) else: # 创建新文件 with open(env_file, 'w', encoding='utf-8') as f: f.write(f"DATABASE_URL={database_url}\n") logger.info(f"✅ .env 文件已更新") logger.info(f" DATABASE_URL={database_url}") return True except Exception as e: logger.error(f"❌ 更新.env文件失败: {e}") return False async def initialize_tables(self) -> bool: """初始化数据库表结构(使用 Alembic)""" try: import subprocess logger.info(f"📋 使用 Alembic 初始化数据库表结构...") # 运行 Alembic 迁移 result = subprocess.run( ["alembic", "upgrade", "head"], capture_output=True, text=True, cwd=Path(__file__).parent.parent ) if result.returncode == 0: logger.info(f"✅ 表结构初始化成功") return True else: logger.error(f"❌ Alembic 迁移失败: {result.stderr}") return False except Exception as e: logger.error(f"❌ 初始化表结构失败: {e}") return False def close(self): """关闭数据库连接""" if self.conn: self.conn.close() logger.info(f"🔌 已断开连接") async def setup(self) -> bool: """执行完整设置流程""" try: # 1. 连接 if not self.connect_as_admin(): return False # 2. 创建用户 if not self.create_user(): return False # 3. 创建数据库 if not self.create_database(): return False # 4. 授予权限 if not self.grant_privileges(): return False # 5. 更新配置 if not self.update_env_file(): return False # 6. 关闭管理员连接 self.close() # 7. 初始化表结构 if not await self.initialize_tables(): return False return True except Exception as e: logger.error(f"❌ 设置过程出错: {e}") return False finally: if self.conn: self.close() async def main(): """主函数""" print(""" ╔═══════════════════════════════════════════════════════════════╗ ║ PostgreSQL 数据库自动设置工具 ║ ║ ║ ║ 此工具将自动完成: ║ ║ 1. 连接到PostgreSQL服务器 ║ ║ 2. 创建数据库和用户 ║ ║ 3. 设置权限 ║ ║ 4. 初始化表结构 ║ ║ 5. 更新.env配置文件 ║ ╚═══════════════════════════════════════════════════════════════╝ """) # 获取配置 print("请输入PostgreSQL配置信息:\n") host = input("主机地址 [localhost]: ").strip() or "localhost" port = input("端口 [5432]: ").strip() or "5432" port = int(port) admin_user = input("管理员用户名 [postgres]: ").strip() or "postgres" admin_password = getpass(f"管理员密码: ") print("\n请输入要创建的数据库信息:\n") db_name = input("数据库名 [mumuai_novel]: ").strip() or "mumuai_novel" db_user = input("数据库用户名 [mumuai]: ").strip() or "mumuai" db_password = getpass("数据库用户密码 [mumuai123]: ") or "mumuai123" print(f"\n{'='*60}") print(f"配置摘要:") print(f" 服务器: {host}:{port}") print(f" 数据库: {db_name}") print(f" 用户: {db_user}") print(f"{'='*60}\n") response = input("确认开始设置? (yes/no): ") if response.lower() not in ['yes', 'y']: print("已取消设置") return # 执行设置 setup = PostgreSQLSetup( host=host, port=port, admin_user=admin_user, admin_password=admin_password, db_name=db_name, db_user=db_user, db_password=db_password ) print(f"\n{'='*60}") success = await setup.setup() print(f"{'='*60}\n") if success: print("🎉 PostgreSQL设置完成!\n") print("下一步:") print("1. 启动应用: python -m app.main") print("2. 访问: http://localhost:8000") print("3. 查看API文档: http://localhost:8000/docs") else: print("❌ 设置过程中出现错误,请检查日志") print("\n故障排查:") print("1. 确认PostgreSQL服务正在运行") print("2. 检查管理员用户名和密码") print("3. 查看PostgreSQL日志") if __name__ == "__main__": asyncio.run(main())