feat: add email verification
This commit is contained in:
@@ -1,19 +1,28 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app.database import get_db, engine, Base
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserResponse
|
||||
from app.models.user import User, EmailVerification
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserResponse, ResendVerificationRequest
|
||||
from app.core.security import get_password_hash, verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
from datetime import timedelta
|
||||
from app.core.email import send_verification_email
|
||||
|
||||
# Create tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def generate_verification_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
@router.post("/auth/login")
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
@@ -45,7 +54,7 @@ def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depend
|
||||
}
|
||||
|
||||
@router.post("/auth/register", response_model=UserResponse)
|
||||
def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||
def register_user(user: UserCreate, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
db_user = db.query(User).filter(User.username == user.username).first()
|
||||
if db_user:
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
@@ -58,19 +67,92 @@ def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||
|
||||
# If this is the first user, make them an admin
|
||||
is_first_user = db.query(User).count() == 0
|
||||
is_admin = is_first_user or user.is_admin
|
||||
is_active = True if is_first_user else False
|
||||
|
||||
db_user = User(
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
hashed_password=hashed_password,
|
||||
is_active=True,
|
||||
is_admin=is_first_user or user.is_admin
|
||||
is_active=is_active,
|
||||
is_admin=is_admin
|
||||
)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
if not is_active:
|
||||
token = generate_verification_token()
|
||||
hashed = hash_token(token)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
verification = EmailVerification(
|
||||
user_id=db_user.id,
|
||||
token_hash=hashed,
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(verification)
|
||||
db.commit()
|
||||
|
||||
# 将用户的 email 保存到局部变量中,防止在后台任务执行前 session 关闭导致延迟加载失败
|
||||
user_email = db_user.email
|
||||
background_tasks.add_task(send_verification_email, user_email, token)
|
||||
|
||||
return db_user
|
||||
|
||||
@router.get("/auth/verify-email")
|
||||
def verify_email(token: str, db: Session = Depends(get_db)):
|
||||
hashed = hash_token(token)
|
||||
verification = db.query(EmailVerification).filter(
|
||||
EmailVerification.token_hash == hashed,
|
||||
EmailVerification.is_used == False
|
||||
).first()
|
||||
|
||||
if not verification:
|
||||
raise HTTPException(status_code=400, detail="Invalid or used token")
|
||||
|
||||
# Check if expired (make timezone-aware if naive)
|
||||
expires_at = verification.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
if expires_at < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=400, detail="Token expired")
|
||||
|
||||
user = db.query(User).filter(User.id == verification.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
user.is_active = True
|
||||
verification.is_used = True
|
||||
db.commit()
|
||||
|
||||
return {"status": "success", "message": "Email verified successfully"}
|
||||
|
||||
@router.post("/auth/resend-verification")
|
||||
def resend_verification(request: ResendVerificationRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
user = db.query(User).filter(User.username == request.username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if user.is_active:
|
||||
raise HTTPException(status_code=400, detail="User already active")
|
||||
|
||||
token = generate_verification_token()
|
||||
hashed = hash_token(token)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
verification = EmailVerification(
|
||||
user_id=user.id,
|
||||
token_hash=hashed,
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(verification)
|
||||
db.commit()
|
||||
|
||||
# 提取 email,避免后台任务访问已断开的 db session
|
||||
user_email = user.email
|
||||
background_tasks.add_task(send_verification_email, user_email, token)
|
||||
return {"status": "success", "message": "Verification email sent"}
|
||||
|
||||
@router.get("/users", response_model=List[UserResponse])
|
||||
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||
users = db.query(User).offset(skip).limit(limit).all()
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import smtplib
|
||||
import os
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
def send_verification_email(to_email: str, token: str):
|
||||
smtp_host = os.getenv("SMTP_HOST", "smtp.qq.com")
|
||||
smtp_port = int(os.getenv("SMTP_PORT", "465"))
|
||||
smtp_user = os.getenv("SMTP_USER", "")
|
||||
smtp_password = os.getenv("SMTP_PASSWORD", "")
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173")
|
||||
|
||||
if not smtp_user or not smtp_password:
|
||||
print("SMTP configuration is missing. Skip sending email.")
|
||||
return
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg['From'] = smtp_user
|
||||
msg['To'] = to_email
|
||||
msg['Subject'] = "Please verify your email address"
|
||||
|
||||
verify_link = f"{frontend_url}/verify-email?token={token}"
|
||||
body = f"""
|
||||
<html>
|
||||
<body>
|
||||
<h2>Welcome to DataClaw!</h2>
|
||||
<p>Please click the link below to verify your email address and activate your account:</p>
|
||||
<p><a href="{verify_link}">{verify_link}</a></p>
|
||||
<p>If you did not request this, please ignore this email.</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
msg.attach(MIMEText(body, 'html'))
|
||||
|
||||
try:
|
||||
# Use SMTP_SSL for port 465
|
||||
server = smtplib.SMTP_SSL(smtp_host, smtp_port)
|
||||
server.login(smtp_user, smtp_password)
|
||||
server.send_message(msg)
|
||||
server.quit()
|
||||
print(f"Verification email sent to {to_email}")
|
||||
except Exception as e:
|
||||
print(f"Failed to send email: {e}")
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
@@ -15,3 +15,16 @@ class User(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
projects = relationship("Project", back_populates="owner")
|
||||
email_verifications = relationship("EmailVerification", back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
class EmailVerification(Base):
|
||||
__tablename__ = "email_verifications"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
token_hash = Column(String, index=True, nullable=False)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
is_used = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
user = relationship("User", back_populates="email_verifications")
|
||||
|
||||
@@ -18,6 +18,9 @@ class UserUpdate(BaseModel):
|
||||
is_admin: Optional[bool] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
class ResendVerificationRequest(BaseModel):
|
||||
username: str
|
||||
|
||||
class UserResponse(UserBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
|
||||
+6
-1
@@ -4,6 +4,11 @@ import binascii
|
||||
from typing import Any, Dict, List, Optional, Literal, Tuple
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载项目根目录下的 .env 文件
|
||||
env_path = Path(__file__).resolve().parent.parent / ".env"
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
@@ -35,7 +40,7 @@ from app.context import (
|
||||
from app.services.knowledge_index import knowledge_index_service
|
||||
from app.database import engine, Base
|
||||
# Import all models to ensure they are registered
|
||||
from app.models.user import User
|
||||
from app.models.user import User, EmailVerification
|
||||
from app.models.project import Project
|
||||
from app.models.datasource import DataSource
|
||||
from app.models.subagent import Subagent
|
||||
|
||||
@@ -31,6 +31,7 @@ dependencies = [
|
||||
"psycopg2-binary>=2.9.11",
|
||||
"pydantic>=2.12.0,<3.0.0",
|
||||
"pydantic-settings>=2.12.0,<3.0.0",
|
||||
"python-dotenv>=1.0.1",
|
||||
"python-jose[cryptography]>=3.5.0",
|
||||
"python-multipart>=0.0.22",
|
||||
"python-socketio>=5.16.0,<6.0.0",
|
||||
|
||||
Generated
+2
@@ -245,6 +245,7 @@ dependencies = [
|
||||
{ name = "psycopg2-binary" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "python-jose", extra = ["cryptography"] },
|
||||
{ name = "python-multipart" },
|
||||
{ name = "python-socketio" },
|
||||
@@ -292,6 +293,7 @@ requires-dist = [
|
||||
{ name = "psycopg2-binary", specifier = ">=2.9.11" },
|
||||
{ name = "pydantic", specifier = ">=2.12.0,<3.0.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.12.0,<3.0.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.1" },
|
||||
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.5.0" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.22" },
|
||||
{ name = "python-socketio", specifier = ">=5.16.0,<6.0.0" },
|
||||
|
||||
Reference in New Issue
Block a user