Skip to main content
  1. Posts/

久仰大名:基于 FastAPI 实现 OAuth2 登录认证

·

写 Pianoforte 时第一个遇到的可复用模块。

最近在通过项目学习 FastAPI,发现登录认证模块可以复用,于是记录一下。代码主要修改自 官方文档

目录结构 #

.
|--app.db
|--requirements.txt
|--app/
|  |--__init__.py
|  |--config.py
|  |--db.py
|  |--main.py
|  |--models.py
|  |--projects/
|  |--users/
|  |  |--__init__.py
|  |  |--auth.py
|  |  |--config.py
|  |  |--crud.py
|  |  |--schemas.py
|  |  |--users.py

最外层是 sqlite 数据库文件和依赖,projects/ 目录和主题无关因此未列出。

现有代码 #

main.py #

from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from .db import engine
from . import models, config

from .users import users, auth
from .projects import projects

models.Base.metadata.create_all(bind=engine)

app = FastAPI(
    title='Pianoforte',
    description='Pianoforte Is AN Offensive Framework Of Red TEam',
    version='0.1.0'
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=config.allow_origins,
    allow_credentials=True,
    allow_methods=config.allow_methods,
    allow_headers=['*']
)

app.include_router(users.router, prefix='/users', tags=['users'])

auth_needed = [Depends(auth.get_current_user)]

app.include_router(projects.router, prefix='/projects', tags=['projects'], dependencies=auth_needed)

首先根据 models.py 中的定义创建了数据库,数据库相关代码在 db.py 中。随后添加 CORS 中间件,并添加 users 路由和 projects 路由,后者需要认证后方能访问。我们通过 Dependency 的方式,要求 auth.get_current_user 函数运行正常才能访问 projects 路由,由此实现权限管理。

db.py #

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

SQLALCHEMY_DATABASE_URL = "sqlite:///./app.db"
# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()


def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

由于项目规模不大,选用 sqlite3 数据库配合 sqlalchemy 的 ORM 功能进行数据库操作。

此处 yield 方式返回数据库 Session 需要 Python 3.7 及以上。

models.py #

from sqlalchemy import Column, Integer, String

from .db import Base


class User(Base):
    __tablename__ = 'users'

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True)
    password = Column(String)

这里仅展示了相关代码,即 User 类的字段。这一段代码在数据库中创建了 users 表来存放用户数据。

由于业务需要,限定 username 不能重复。

users/schemas.py #

上面的 models.py 定义了数据库中的 User 类,而在本文件中定义了两种语义下的 User 类:

from pydantic import BaseModel


class UserBase(BaseModel):
    username: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    id: int

    class Config:
        orm_mode = True

其一是在注册与登录时使用的 UserCreate 类,此时不需要传输 id;其二是在获取用户信息时返回的 User 类,此时不需要传输 password 但需要 id

users/crud.py #

from sqlalchemy.orm import Session

from ..models import User
from . import schemas


def get(db: Session, username: str):
    return db.query(User).filter_by(username=username).first()


def create(db: Session, param: schemas.UserCreate):
    target = User(**param.dict())
    db.add(target)
    db.commit()
    db.refresh(target)
    return target

关于数据库操作,这里只涉及到增和查操作。由于 username 唯一,可以使用 username 作为关键字查询。

在插入数据前,需要将用户传入的数据解构并构建 User 数据库对象。

认证功能实现 #

我们计划用 OAuth2 的用户名密码认证,并将用户密码哈希后存入数据库。为了维护登陆状态,这里采用 jwt 代替较为繁琐的 cookie 机制。

注册 #

注册部分较为简单(验证码等部分与主题无关,未展示),在 users.py 中编写路由函数:

from fastapi import APIRouter, Depends, HTTPException

from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from . import crud
from .auth import get_current_user, auth_user, gen_token, hash_password
from .schemas import User, UserCreate

from ..db import get_db

router = APIRouter()
# ...
@router.post('/')
def create(param: UserCreate, db: Session = Depends(get_db)):
    param.password = hash_password(param.password)
    try:
        target = crud.create(db=db, param=param)
    except IntegrityError:
        raise HTTPException(status_code=400, detail='Duplicate username')
    return 0

利用 Depends(get_db) 获取一个数据库 Session,然后借助 crud.create 插入密码被哈希后的用户数据。这里的 hash_password 来源于 auth.py

from passlib.context import CryptContext

pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')

def hash_password(password):
    return pwd_context.hash(password)

我们采用了 bcrypt 算法进行哈希。同理,验证函数同样简单:

def verify_password(plain, hashed):
    return pwd_context.verify(plain, hashed)

登录:验证用户信息 #

对于给定的 usernamepassword,我们需要通过数据库查询验证用户名和密码是否合法:

from sqlalchemy.orm import Session
from . import crud

def auth_user(db: Session, username: str, password: str):
    user = crud.get(db=db, username=username)
    if user is None:
        return False
    if not verify_password(password, user.password):
        return False
    return user

根据 OAuth2 标准,接收用户名和密码时需要使用 application/x-www-form-urlencoded 格式,当用户名或密码错误时返回 401,且返回自定义 HTTP 头 WWW-Authenticate: Bearer,其中 Bearer 是我们这里使用的携带 token 的方式:

from typing import Dict
from fastapi.security import OAuth2PasswordRequestForm

@router.post('/login', response_model=Dict[str, str])
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    user = auth_user(db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=401, detail='Incorrect username or password', headers={'WWW-Authenticate': 'Bearer'})
    # ...

登录:生成 jwt #

对于给定的数据以及 token 失效时间,借助 python-jose 库,参考 jwt 标准生成 jwt:

from jose import JWTError, jwt
from datetime import datetime, timedelta

from .config import SECRET_KEY, ALGORITHM

def gen_token(data: dict, expires_delta: timedelta = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({'exp': expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

    return encoded_jwt

这里的 data,根据 jwt 标准可以设置 sub 字段为用户的 username,失效时间也可以在程序中自定义:

from .config import ACCESS_TOKEN_EXPIRE_MINUTES

@router.post('/login', response_model=Dict[str, str])
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    user = auth_user(db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=401, detail='Incorrect username or password', headers={'WWW-Authenticate': 'Bearer'})
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = gen_token(
        data={'sub': user.username}, expires_delta=access_token_expires)

    return {
        'access_token': access_token,
        'token_type': 'bearer'
    }

最后根据 OAuth2 标准返回 {access_token, token_type} 对象。

根据 jwt 获取用户信息 #

先通过 jwt.decode 解码信息并验证签名,随后对得到的信息进行解析,最后回到数据库中验证解析后的信息:

from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer

from ..db import get_db

oauth2_scheme = OAuth2PasswordBearer(tokenUrl='users/login')

def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    credentials_exception = HTTPException(
        status_code=401, detail='Could not validate credentials', headers={'WWW-Authenticate': 'Bearer'})
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get('sub')
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = crud.get(db=db, username=username)
    if user is None:
        raise credentials_exception

    return user

这里 OAuth2PasswordBearer(tokenUrl='users/login') 实际上和 login 函数的参数 form_data: OAuth2PasswordRequestForm 对应,形成完整的 OAuth2 password flow。

然后在 users.py 中暴露获取当前登陆用户信息的接口:

@router.get('/', response_model=User)
async def get(current_user: User = Depends(get_current_user)):
    return current_user

测试 #

app/ 所在目录运行:

python3 -m uvicorn app.main:app --reload

随后访问 http://localhost:8000/docs 即可看到 swagger 文档。

图 1

图 2