121 lines
3.3 KiB
Go
121 lines
3.3 KiB
Go
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
|
||
"nyanimedb/auth"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/golang-jwt/jwt/v5"
|
||
)
|
||
|
||
// ctxKey — приватный тип для ключа контекста
|
||
type ctxKey struct{}
|
||
|
||
// ginContextKey — уникальный ключ для хранения *gin.Context
|
||
var ginContextKey = &ctxKey{}
|
||
|
||
// GinContextToContext сохраняет *gin.Context в context.Context запроса
|
||
func GinContextToContext(c *gin.Context) {
|
||
ctx := context.WithValue(c.Request.Context(), ginContextKey, c)
|
||
c.Request = c.Request.WithContext(ctx)
|
||
}
|
||
|
||
// GinContextFromContext извлекает *gin.Context из context.Context
|
||
func GinContextFromContext(ctx context.Context) (*gin.Context, bool) {
|
||
ginCtx, ok := ctx.Value(ginContextKey).(*gin.Context)
|
||
return ginCtx, ok
|
||
}
|
||
|
||
func JWTAuthMiddleware(secret string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 1. Получаем access_token из cookie
|
||
tokenStr, err := c.Cookie("access_token")
|
||
if err != nil {
|
||
abortWithJSON(c, http.StatusUnauthorized, "missing access_token cookie")
|
||
return
|
||
}
|
||
|
||
// 2. Парсим токен с MapClaims
|
||
token, err := jwt.ParseWithClaims(tokenStr, &auth.TokenClaims{}, func(t *jwt.Token) (interface{}, error) {
|
||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
return nil, fmt.Errorf("unexpected signing method")
|
||
}
|
||
return []byte(secret), nil
|
||
})
|
||
// token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
|
||
// if t.Method != jwt.SigningMethodHS256 {
|
||
// return nil, errors.New("unexpected signing method: " + t.Method.Alg())
|
||
// }
|
||
// return []byte(secret), nil // ← конвертируем string → []byte
|
||
// })
|
||
if err != nil {
|
||
abortWithJSON(c, http.StatusUnauthorized, "invalid token: "+err.Error())
|
||
return
|
||
}
|
||
|
||
// 3. Проверяем валидность
|
||
if !token.Valid {
|
||
abortWithJSON(c, http.StatusUnauthorized, "token is invalid")
|
||
return
|
||
}
|
||
|
||
// 4. Извлекаем user_id из claims
|
||
claims, ok := token.Claims.(*auth.TokenClaims)
|
||
if !ok {
|
||
abortWithJSON(c, http.StatusUnauthorized, "invalid claims format")
|
||
return
|
||
}
|
||
|
||
if claims.Subject == "" {
|
||
abortWithJSON(c, http.StatusUnauthorized, "user_id claim missing or invalid")
|
||
return
|
||
}
|
||
if claims.Type != "access" {
|
||
abortWithJSON(c, http.StatusUnauthorized, "token type is not access")
|
||
return
|
||
}
|
||
|
||
// 5. Сохраняем в контексте
|
||
c.Set("user_id", claims.Subject)
|
||
|
||
// 6. Для oapi-codegen — кладём gin.Context в request context
|
||
GinContextToContext(c)
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// Вспомогательные функции (без изменений)
|
||
func UserIDFromGin(c *gin.Context) (string, bool) {
|
||
id, exists := c.Get("user_id")
|
||
if !exists {
|
||
return "", false
|
||
}
|
||
if s, ok := id.(string); ok {
|
||
return s, true
|
||
}
|
||
return "", false
|
||
}
|
||
|
||
func UserIDFromContext(ctx context.Context) (string, error) {
|
||
ginCtx, ok := GinContextFromContext(ctx)
|
||
if !ok {
|
||
return "", errors.New("gin context not found")
|
||
}
|
||
userID, ok := UserIDFromGin(ginCtx)
|
||
if !ok {
|
||
return "", errors.New("user_id not found in context")
|
||
}
|
||
return userID, nil
|
||
}
|
||
|
||
func abortWithJSON(c *gin.Context, code int, message string) {
|
||
c.AbortWithStatusJSON(code, gin.H{
|
||
"error": "unauthorized",
|
||
"message": message,
|
||
})
|
||
}
|