feat: add email verification
This commit is contained in:
@@ -1,19 +1,28 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app.database import get_db, engine, Base
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserResponse
|
||||
from app.models.user import User, EmailVerification
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserResponse, ResendVerificationRequest
|
||||
from app.core.security import get_password_hash, verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
from datetime import timedelta
|
||||
from app.core.email import send_verification_email
|
||||
|
||||
# Create tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def generate_verification_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
@router.post("/auth/login")
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
@@ -45,7 +54,7 @@ def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depend
|
||||
}
|
||||
|
||||
@router.post("/auth/register", response_model=UserResponse)
|
||||
def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||
def register_user(user: UserCreate, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
db_user = db.query(User).filter(User.username == user.username).first()
|
||||
if db_user:
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
@@ -58,19 +67,92 @@ def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||
|
||||
# If this is the first user, make them an admin
|
||||
is_first_user = db.query(User).count() == 0
|
||||
is_admin = is_first_user or user.is_admin
|
||||
is_active = True if is_first_user else False
|
||||
|
||||
db_user = User(
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
hashed_password=hashed_password,
|
||||
is_active=True,
|
||||
is_admin=is_first_user or user.is_admin
|
||||
is_active=is_active,
|
||||
is_admin=is_admin
|
||||
)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
if not is_active:
|
||||
token = generate_verification_token()
|
||||
hashed = hash_token(token)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
verification = EmailVerification(
|
||||
user_id=db_user.id,
|
||||
token_hash=hashed,
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(verification)
|
||||
db.commit()
|
||||
|
||||
# 将用户的 email 保存到局部变量中,防止在后台任务执行前 session 关闭导致延迟加载失败
|
||||
user_email = db_user.email
|
||||
background_tasks.add_task(send_verification_email, user_email, token)
|
||||
|
||||
return db_user
|
||||
|
||||
@router.get("/auth/verify-email")
|
||||
def verify_email(token: str, db: Session = Depends(get_db)):
|
||||
hashed = hash_token(token)
|
||||
verification = db.query(EmailVerification).filter(
|
||||
EmailVerification.token_hash == hashed,
|
||||
EmailVerification.is_used == False
|
||||
).first()
|
||||
|
||||
if not verification:
|
||||
raise HTTPException(status_code=400, detail="Invalid or used token")
|
||||
|
||||
# Check if expired (make timezone-aware if naive)
|
||||
expires_at = verification.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
if expires_at < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=400, detail="Token expired")
|
||||
|
||||
user = db.query(User).filter(User.id == verification.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
user.is_active = True
|
||||
verification.is_used = True
|
||||
db.commit()
|
||||
|
||||
return {"status": "success", "message": "Email verified successfully"}
|
||||
|
||||
@router.post("/auth/resend-verification")
|
||||
def resend_verification(request: ResendVerificationRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
user = db.query(User).filter(User.username == request.username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if user.is_active:
|
||||
raise HTTPException(status_code=400, detail="User already active")
|
||||
|
||||
token = generate_verification_token()
|
||||
hashed = hash_token(token)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
verification = EmailVerification(
|
||||
user_id=user.id,
|
||||
token_hash=hashed,
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(verification)
|
||||
db.commit()
|
||||
|
||||
# 提取 email,避免后台任务访问已断开的 db session
|
||||
user_email = user.email
|
||||
background_tasks.add_task(send_verification_email, user_email, token)
|
||||
return {"status": "success", "message": "Verification email sent"}
|
||||
|
||||
@router.get("/users", response_model=List[UserResponse])
|
||||
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||
users = db.query(User).offset(skip).limit(limit).all()
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import smtplib
|
||||
import os
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
def send_verification_email(to_email: str, token: str):
|
||||
smtp_host = os.getenv("SMTP_HOST", "smtp.qq.com")
|
||||
smtp_port = int(os.getenv("SMTP_PORT", "465"))
|
||||
smtp_user = os.getenv("SMTP_USER", "")
|
||||
smtp_password = os.getenv("SMTP_PASSWORD", "")
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173")
|
||||
|
||||
if not smtp_user or not smtp_password:
|
||||
print("SMTP configuration is missing. Skip sending email.")
|
||||
return
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg['From'] = smtp_user
|
||||
msg['To'] = to_email
|
||||
msg['Subject'] = "Please verify your email address"
|
||||
|
||||
verify_link = f"{frontend_url}/verify-email?token={token}"
|
||||
body = f"""
|
||||
<html>
|
||||
<body>
|
||||
<h2>Welcome to DataClaw!</h2>
|
||||
<p>Please click the link below to verify your email address and activate your account:</p>
|
||||
<p><a href="{verify_link}">{verify_link}</a></p>
|
||||
<p>If you did not request this, please ignore this email.</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
msg.attach(MIMEText(body, 'html'))
|
||||
|
||||
try:
|
||||
# Use SMTP_SSL for port 465
|
||||
server = smtplib.SMTP_SSL(smtp_host, smtp_port)
|
||||
server.login(smtp_user, smtp_password)
|
||||
server.send_message(msg)
|
||||
server.quit()
|
||||
print(f"Verification email sent to {to_email}")
|
||||
except Exception as e:
|
||||
print(f"Failed to send email: {e}")
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
@@ -15,3 +15,16 @@ class User(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
projects = relationship("Project", back_populates="owner")
|
||||
email_verifications = relationship("EmailVerification", back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
class EmailVerification(Base):
|
||||
__tablename__ = "email_verifications"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
token_hash = Column(String, index=True, nullable=False)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
is_used = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
user = relationship("User", back_populates="email_verifications")
|
||||
|
||||
@@ -18,6 +18,9 @@ class UserUpdate(BaseModel):
|
||||
is_admin: Optional[bool] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
class ResendVerificationRequest(BaseModel):
|
||||
username: str
|
||||
|
||||
class UserResponse(UserBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
|
||||
Reference in New Issue
Block a user