
Django REST Framework 生产环境实战:认证、限流与部署

构建生产级 Django REST API，涵盖 Simple JWT 认证、权限类、限流、序列化器优化、Celery 异步任务及 Docker 部署。

Django REST Framework 生产环境实战:认证、限流与部署

Django REST Framework 生产环境配置

DRF 成熟稳定、久经考验。下面介绍如何针对真实的生产流量对它进行配置。

Django REST Framework 生产环境实战:认证、限流与部署 插图

settings 配置

# settings/production.py
# Production overrides layered on top of the shared base settings.
from .base import *  # noqa: F401,F403 -- inherit shared settings (INSTALLED_APPS, ...)
import os
from datetime import timedelta  # used by SIMPLE_JWT token lifetimes below

DEBUG = False

# Comma-separated host list, e.g. "api.example.com,www.example.com".
# Entries are stripped and blanks dropped, so "a.com, b.com" works and an
# unset variable yields [] instead of [''].
ALLOWED_HOSTS = [
    host.strip()
    for host in os.getenv('ALLOWED_HOSTS', '').split(',')
    if host.strip()
]

INSTALLED_APPS += [
    'rest_framework',
    'rest_framework_simplejwt',
    # Required for BLACKLIST_AFTER_ROTATION below: without this app rotated
    # refresh tokens are not blacklisted and remain usable.
    # Run `manage.py migrate` after adding it.
    'rest_framework_simplejwt.token_blacklist',
    # NOTE(review): corsheaders also needs its middleware and allowed-origin
    # settings -- confirm they are configured in base settings.
    'corsheaders',
    'django_filters',
]

REST_FRAMEWORK = {
    # Every request authenticates via a "Bearer <JWT>" header.
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework_simplejwt.authentication.JWTAuthentication',
    ],
    # Deny by default; views opt out explicitly where needed.
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.IsAuthenticated',
    ],
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {
        'anon': '100/hour',
        'user': '1000/hour',
    },
    # CursorPagination orders by '-created' unless the view provides an
    # ordering; each paginated view should set `ordering` to a real field.
    'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.CursorPagination',
    'PAGE_SIZE': 20,
    'DEFAULT_FILTER_BACKENDS': [
        'django_filters.rest_framework.DjangoFilterBackend',
        'rest_framework.filters.SearchFilter',
        'rest_framework.filters.OrderingFilter',
    ],
}

SIMPLE_JWT = {
    # Short-lived access tokens; longer-lived refresh tokens.
    'ACCESS_TOKEN_LIFETIME': timedelta(minutes=15),
    'REFRESH_TOKEN_LIFETIME': timedelta(days=7),
    # Each refresh issues a new refresh token and blacklists the old one
    # (needs the token_blacklist app installed above).
    'ROTATE_REFRESH_TOKENS': True,
    'BLACKLIST_AFTER_ROTATION': True,
    'AUTH_HEADER_TYPES': ('Bearer',),
}

序列化器优化

# serializers.py
from rest_framework import serializers
from .models import User, Post

class UserSerializer(serializers.ModelSerializer):
    """Read-oriented representation of a user.

    ``post_count`` is not a model column; it is expected to be supplied by the
    queryset, e.g. ``annotate(post_count=Count('posts'))`` as UserViewSet does.
    """

    # Comes from a queryset annotation; clients can never write it.
    post_count = serializers.IntegerField(read_only=True)

    class Meta:
        model = User
        fields = ('id', 'username', 'email', 'post_count', 'created_at')
        read_only_fields = ('id', 'created_at')

class CreateUserSerializer(serializers.ModelSerializer):
    """Sign-up serializer: validates a new account and creates it.

    ``password`` is write-only, so it is never echoed back in responses.
    """

    password = serializers.CharField(write_only=True, min_length=8)

    class Meta:
        model = User
        fields = ['username', 'email', 'password']

    def validate_email(self, value):
        """Reject an email already used by another account (case-insensitive)."""
        # iexact: "Bob@Example.com" and "bob@example.com" must not both be
        # registered. This check is racy under concurrency; a database unique
        # constraint is the real guarantee.
        if User.objects.filter(email__iexact=value).exists():
            raise serializers.ValidationError("Email already taken")
        return value

    def validate_password(self, value):
        """Run the project's AUTH_PASSWORD_VALIDATORS against the raw password."""
        # Local import keeps the module-level import block unchanged.
        from django.contrib.auth import password_validation

        # Raises django.core.exceptions.ValidationError, which DRF reports as
        # a field error on ``password``.
        password_validation.validate_password(value)
        return value

    def create(self, validated_data):
        """Create the user through ``create_user`` so the password is hashed."""
        return User.objects.create_user(**validated_data)

class PostSerializer(serializers.ModelSerializer):
    """Serialize a post together with a nested, read-only author."""

    # Nested author representation; it cannot be set through this serializer.
    author = UserSerializer(read_only=True)

    class Meta:
        model = Post
        fields = ('id', 'title', 'content', 'author', 'published_at')

Django REST Framework 生产环境实战:认证、限流与部署 插图

ViewSets

# views.py
from django.db.models import Count, Prefetch
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import filters, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated, IsAuthenticatedOrReadOnly
from rest_framework.response import Response

class UserViewSet(viewsets.ModelViewSet):
    """User CRUD endpoints plus the ``me`` and ``change_password`` actions.

    Reads fall under IsAuthenticatedOrReadOnly. Actions that modify one
    specific account are limited to that account's owner or staff.
    """

    serializer_class = UserSerializer
    # Baseline policy; get_permissions() tightens it per action.
    # NOTE(review): this also requires authentication for `create`; if sign-up
    # should be open to anonymous callers, return [AllowAny()] for that action.
    permission_classes = [IsAuthenticatedOrReadOnly]
    filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['username', 'email']
    ordering_fields = ['created_at', 'username']
    # Default sort for OrderingFilter, which CursorPagination also consults.
    # Without it the paginator falls back to '-created', which is not the
    # `created_at` field this view exposes.
    ordering = ['-created_at']
    filterset_fields = ['is_active', 'role']

    # Actions targeting one specific account: only that user or staff.
    _self_or_staff_actions = frozenset(
        {'update', 'partial_update', 'destroy', 'change_password'}
    )

    def get_queryset(self):
        # annotate() computes post_count in the same query instead of one
        # COUNT query per user (avoids N+1).
        # NOTE(review): UserSerializer reads neither `profile` nor `groups`;
        # drop these joins if nothing else on this view needs them.
        return User.objects.annotate(
            post_count=Count('posts')
        ).select_related('profile').prefetch_related('groups')

    def get_serializer_class(self):
        # Sign-up uses the write-oriented serializer (password + email check).
        if self.action == 'create':
            return CreateUserSerializer
        return UserSerializer

    def get_permissions(self):
        """Return the permission instances that apply to the current action."""
        if self.action in self._self_or_staff_actions:
            # IsAdminOrSelf is object-level; get_object() enforces it.
            return [IsAuthenticated(), IsAdminOrSelf()]
        if self.action == 'me':
            # `me` is a GET, so the read-only baseline would admit anonymous
            # callers who have no account to return.
            return [IsAuthenticated()]
        return super().get_permissions()

    @action(detail=True, methods=['post'])
    def change_password(self, request, pk=None):
        """Set a new password on the target user (the user themself or staff)."""
        # get_object() runs the object-level checks from get_permissions().
        user = self.get_object()
        serializer = ChangePasswordSerializer(
            data=request.data, context=self.get_serializer_context()
        )
        serializer.is_valid(raise_exception=True)
        user.set_password(serializer.validated_data['new_password'])
        user.save()
        return Response({'status': 'password changed'})

    @action(detail=False, methods=['get'])
    def me(self, request):
        """Return the authenticated user with the same fields as the detail view."""
        # Re-fetch through get_queryset() so post_count is annotated;
        # request.user itself carries no such attribute.
        user = self.get_queryset().get(pk=request.user.pk)
        serializer = self.get_serializer(user)
        return Response(serializer.data)

自定义权限

from rest_framework.permissions import BasePermission, SAFE_METHODS

class IsOwnerOrReadOnly(BasePermission):
    """Allow safe (read-only) methods to everyone; writes only to the owner.

    The owner is read from the attribute named by ``owner_field`` (default
    ``'owner'``). Subclass and override it for models that store the owner
    under another name, e.g. ``owner_field = 'author'`` for ``Post``.
    """

    # Name of the attribute on the object that holds its owning user.
    owner_field = 'owner'

    def has_object_permission(self, request, view, obj):
        if request.method in SAFE_METHODS:
            return True
        return getattr(obj, self.owner_field) == request.user

class IsAdminOrSelf(BasePermission):
    """Object-level check: staff may act on any object, others only on themselves."""

    def has_object_permission(self, request, view, obj):
        requester = request.user
        if requester.is_staff:
            return True
        # Non-staff callers are allowed only when the target is their own account.
        return obj == requester

Django REST Framework 生产环境实战:认证、限流与部署 插图

Celery 后台任务

# tasks.py
from celery import shared_task
from celery.utils.log import get_task_logger
from django.core.mail import send_mail

logger = get_task_logger(__name__)


@shared_task(bind=True, max_retries=3)
def send_welcome_email(self, user_id: int):
    """Send the welcome email to the user whose primary key is ``user_id``.

    A missing user is logged and not retried. Any other failure is logged and
    retried after 60 seconds, up to ``max_retries`` (3) times.
    """
    # Resolve the user model at call time; this module does not import it.
    from django.contrib.auth import get_user_model

    User = get_user_model()
    try:
        user = User.objects.get(id=user_id)
        send_mail(
            subject='Welcome!',
            message=f'Hello {user.username}, welcome aboard!',
            from_email='noreply@example.com',
            recipient_list=[user.email],
        )
    except User.DoesNotExist:
        logger.error('User %s not found', user_id)
        return
    except Exception as exc:
        logger.error('Failed to send email: %s', exc)
        raise self.retry(exc=exc, countdown=60)
    # Lazy %-style args: formatting only happens if the record is emitted.
    logger.info('Welcome email sent to %s', user.email)

# Enqueue the task from a call site such as a view (`user` is a saved user).
# on_commit() defers the enqueue until the surrounding transaction commits, so
# the worker can see the new row instead of hitting User.DoesNotExist.
# Outside a transaction the callback runs immediately. The id is bound as a
# default argument so later rebinding of `user` cannot change it.
from django.db import transaction

transaction.on_commit(lambda user_id=user.id: send_welcome_email.delay(user_id))

URL 路由

# urls.py
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView

from .views import PostViewSet, UserViewSet

router = DefaultRouter()
# Explicit basenames are required: both viewsets override get_queryset()
# instead of defining a `queryset` attribute, so the router cannot derive one.
router.register(r'users', UserViewSet, basename='user')
router.register(r'posts', PostViewSet, basename='post')

urlpatterns = [
    path('api/v1/', include(router.urls)),
    # JWT: obtain an access/refresh pair, then exchange a refresh token.
    path('api/token/', TokenObtainPairView.as_view(), name='token_obtain_pair'),
    path('api/token/refresh/', TokenRefreshView.as_view(), name='token_refresh'),
]

查询优化

# Avoid N+1 queries: use select_related (SQL joins for single-valued
# relations) and prefetch_related (one batched query per multi-valued relation).
class PostViewSet(viewsets.ModelViewSet):
    """Post CRUD endpoints with related objects loaded up front."""

    serializer_class = PostSerializer
    # CursorPagination falls back to '-created' unless the view supplies an
    # ordering; PostSerializer exposes `published_at`, so order by that.
    # NOTE(review): cursor pagination needs a non-null, effectively unique
    # ordering field -- confirm published_at qualifies.
    ordering = ['-published_at']

    def get_queryset(self):
        # No .only(). PostSerializer reads `content`, and the nested
        # UserSerializer reads `email`/`created_at`. Deferring those columns
        # would cost one extra query per row. Deferring author fields while
        # select_related() traverses `author__profile` can also raise FieldError.
        return Post.objects.select_related(
            'author',
            'author__profile',
        ).prefetch_related(
            'tags',
            # Only approved comments are loaded into post.comments.
            Prefetch('comments', queryset=Comment.objects.filter(approved=True)),
        )

    def perform_create(self, serializer):
        # `author` is read-only on PostSerializer, so set it from the caller.
        serializer.save(author=self.request.user)