diff --git a/auth/auth.gen.go b/auth/auth.gen.go index 1e8803e..fd7a224 100644 --- a/auth/auth.gen.go +++ b/auth/auth.gen.go @@ -56,6 +56,9 @@ type ServerInterface interface { // Get service impersontaion token // (POST /get-impersonation-token) GetImpersonationToken(c *gin.Context) + // Refreshes access_token and refresh_token + // (GET /refresh-tokens) + RefreshTokens(c *gin.Context) // Sign in a user and return JWT // (POST /sign-in) PostSignIn(c *gin.Context) @@ -88,6 +91,19 @@ func (siw *ServerInterfaceWrapper) GetImpersonationToken(c *gin.Context) { siw.Handler.GetImpersonationToken(c) } +// RefreshTokens operation middleware +func (siw *ServerInterfaceWrapper) RefreshTokens(c *gin.Context) { + + for _, middleware := range siw.HandlerMiddlewares { + middleware(c) + if c.IsAborted() { + return + } + } + + siw.Handler.RefreshTokens(c) +} + // PostSignIn operation middleware func (siw *ServerInterfaceWrapper) PostSignIn(c *gin.Context) { @@ -142,10 +158,17 @@ func RegisterHandlersWithOptions(router gin.IRouter, si ServerInterface, options } router.POST(options.BaseURL+"/get-impersonation-token", wrapper.GetImpersonationToken) + router.GET(options.BaseURL+"/refresh-tokens", wrapper.RefreshTokens) router.POST(options.BaseURL+"/sign-in", wrapper.PostSignIn) router.POST(options.BaseURL+"/sign-up", wrapper.PostSignUp) } +type ClientErrorResponse struct { +} + +type ServerErrorResponse struct { +} + type UnauthorizedErrorResponse struct { } @@ -176,6 +199,42 @@ func (response GetImpersonationToken401Response) VisitGetImpersonationTokenRespo return nil } +type RefreshTokensRequestObject struct { +} + +type RefreshTokensResponseObject interface { + VisitRefreshTokensResponse(w http.ResponseWriter) error +} + +type RefreshTokens200Response struct { +} + +func (response RefreshTokens200Response) VisitRefreshTokensResponse(w http.ResponseWriter) error { + w.WriteHeader(200) + return nil +} + +type RefreshTokens400Response = ClientErrorResponse + +func (response RefreshTokens400Response) VisitRefreshTokensResponse(w http.ResponseWriter) error { + w.WriteHeader(400) + return nil +} + +type RefreshTokens401Response = UnauthorizedErrorResponse + +func (response RefreshTokens401Response) VisitRefreshTokensResponse(w http.ResponseWriter) error { + w.WriteHeader(401) + return nil +} + +type RefreshTokens500Response = ServerErrorResponse + +func (response RefreshTokens500Response) VisitRefreshTokensResponse(w http.ResponseWriter) error { + w.WriteHeader(500) + return nil +} + type PostSignInRequestObject struct { Body *PostSignInJSONRequestBody } @@ -227,6 +286,9 @@ type StrictServerInterface interface { // Get service impersontaion token // (POST /get-impersonation-token) GetImpersonationToken(ctx context.Context, request GetImpersonationTokenRequestObject) (GetImpersonationTokenResponseObject, error) + // Refreshes access_token and refresh_token + // (GET /refresh-tokens) + RefreshTokens(ctx context.Context, request RefreshTokensRequestObject) (RefreshTokensResponseObject, error) // Sign in a user and return JWT // (POST /sign-in) PostSignIn(ctx context.Context, request PostSignInRequestObject) (PostSignInResponseObject, error) @@ -280,6 +342,31 @@ func (sh *strictHandler) GetImpersonationToken(ctx *gin.Context) { } } +// RefreshTokens operation middleware +func (sh *strictHandler) RefreshTokens(ctx *gin.Context) { + var request RefreshTokensRequestObject + + handler := func(ctx *gin.Context, request interface{}) (interface{}, error) { + return sh.ssi.RefreshTokens(ctx, request.(RefreshTokensRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "RefreshTokens") + } + + response, err := handler(ctx, request) + + if err != nil { + ctx.Error(err) + ctx.Status(http.StatusInternalServerError) + } else if validResponse, ok := response.(RefreshTokensResponseObject); ok { + if err := validResponse.VisitRefreshTokensResponse(ctx.Writer); err != nil { + ctx.Error(err) + } + } else if response != nil { + ctx.Error(fmt.Errorf("unexpected response type: %T", response)) + } +} + // PostSignIn operation middleware func (sh *strictHandler) PostSignIn(ctx *gin.Context) { var request PostSignInRequestObject diff --git a/auth/claims.go b/auth/claims.go new file mode 100644 index 0000000..d888a1b --- /dev/null +++ b/auth/claims.go @@ -0,0 +1,10 @@ +package auth + +import "github.com/golang-jwt/jwt/v5" + +type TokenClaims struct { + UserID string `json:"user_id"` + Type string `json:"type"` + ImpID *string `json:"imp_id,omitempty"` + jwt.RegisteredClaims +} diff --git a/auth/openapi-auth.yaml b/auth/openapi-auth.yaml index 803a4ae..e95e8c2 100644 --- a/auth/openapi-auth.yaml +++ b/auth/openapi-auth.yaml @@ -116,6 +116,22 @@ paths: "401": $ref: '#/components/responses/UnauthorizedError' + /refresh-tokens: + get: + summary: Refreshes access_token and refresh_token + operationId: refreshTokens + tags: [Auth] + responses: + # This one sets two cookies: access_token and refresh_token + "200": + description: Refresh success + "400": + $ref: '#/components/responses/ClientError' + "401": + $ref: '#/components/responses/UnauthorizedError' + "500": + $ref: '#/components/responses/ServerError' + components: securitySchemes: bearerAuth: @@ -123,4 +139,8 @@ components: scheme: bearer responses: UnauthorizedError: - description: Access token is missing or invalid \ No newline at end of file + description: Access token is missing or invalid + ServerError: + description: ServerError + ClientError: + description: ClientError \ No newline at end of file diff --git a/modules/auth/handlers/handlers.go b/modules/auth/handlers/handlers.go index 3af44f3..4f67448 100644 --- a/modules/auth/handlers/handlers.go +++ b/modules/auth/handlers/handlers.go @@ -47,28 +47,35 @@ func CheckPassword(password, hash string) (bool, error) { return argon2id.ComparePasswordAndHash(password, hash) } -func (s Server) generateImpersonationToken(userID string, impersonated_by string) (accessToken string, err error) { - accessClaims := jwt.MapClaims{ - "user_id": userID, - "exp": time.Now().Add(15 * time.Minute).Unix(), - "imp_id": impersonated_by, +func (s *Server) generateImpersonationToken(userID string, impersonatedBy string) (string, error) { + now := time.Now() + claims := auth.TokenClaims{ + UserID: userID, + ImpID: &impersonatedBy, + Type: "access", + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(15 * time.Minute)), + ID: generateJTI(), + }, } - at := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims) - - accessToken, err = at.SignedString([]byte(s.JwtPrivateKey)) - if err != nil { - return "", err - } - - return accessToken, nil + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(s.JwtPrivateKey)) } -func (s Server) generateTokens(userID string) (accessToken string, refreshToken string, csrfToken string, err error) { - accessClaims := jwt.MapClaims{ - "user_id": userID, - "exp": time.Now().Add(15 * time.Minute).Unix(), - //TODO: add created_at +func (s *Server) generateTokens(userID string) (accessToken string, refreshToken string, csrfToken string, err error) { + now := time.Now() + + // Access token (15 мин) + accessClaims := auth.TokenClaims{ + UserID: userID, + Type: "access", + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(15 * time.Minute)), + ID: generateJTI(), + }, } at := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims) accessToken, err = at.SignedString([]byte(s.JwtPrivateKey)) @@ -76,9 +83,15 @@ func (s Server) generateTokens(userID string) (accessToken string, refreshToken return "", "", "", err } - refreshClaims := jwt.MapClaims{ - "user_id": userID, - "exp": time.Now().Add(7 * 24 * time.Hour).Unix(), + // Refresh token (7 дней) + refreshClaims := auth.TokenClaims{ + UserID: userID, + Type: "refresh", + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(7 * 24 * time.Hour)), + ID: generateJTI(), + }, } rt := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims) refreshToken, err = rt.SignedString([]byte(s.JwtPrivateKey)) @@ -86,6 +99,7 @@ func (s Server) generateTokens(userID string) (accessToken string, refreshToken return "", "", "", err } + // CSRF token csrfBytes := make([]byte, 32) _, err = rand.Read(csrfBytes) if err != nil { @@ -219,56 +233,56 @@ func (s Server) GetImpersonationToken(ctx context.Context, req auth.GetImpersona return auth.GetImpersonationToken200JSONResponse{AccessToken: accessToken}, nil } -// func (s Server) PostAuthRefreshToken(ctx context.Context, req auth.PostAuthRefreshTokenRequestObject) (auth.PostAuthRefreshTokenResponseObject, error) { -// valid := false -// var userID *string -// var errStr *string +func (s Server) RefreshTokens(ctx context.Context, req auth.RefreshTokensRequestObject) (auth.RefreshTokensResponseObject, error) { + ginCtx, ok := ctx.Value(gin.ContextKey).(*gin.Context) + if !ok { + log.Print("failed to get gin context") + return auth.RefreshTokens500Response{}, fmt.Errorf("failed to get gin.Context from context.Context") + } -// token, err := jwt.Parse(req.Body.Token, func(t *jwt.Token) (interface{}, error) { -// if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { -// return nil, fmt.Errorf("unexpected signing method") -// } -// return refreshSecret, nil -// }) + rtCookie, err := ginCtx.Request.Cookie("refresh_token") + if err != nil { + log.Print("failed to get refresh_token cookie") + return auth.RefreshTokens400Response{}, fmt.Errorf("failed to get refresh_token cookie") + } -// if err != nil { -// e := err.Error() -// errStr = &e -// return auth.PostAuthVerifyToken200JSONResponse{ -// Valid: &valid, -// UserId: userID, -// Error: errStr, -// }, nil -// } + refreshToken := rtCookie.Value -// if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { -// if uid, ok := claims["user_id"].(string); ok { -// // Refresh token is valid, generate new tokens -// newAccessToken, newRefreshToken, _ := generateTokens(uid) -// valid = true -// userID = &uid -// return auth.PostAuthVerifyToken200JSONResponse{ -// Valid: &valid, -// UserId: userID, -// Error: nil, -// Token: &newAccessToken, // return new access token -// // optionally return newRefreshToken as well -// }, nil -// } else { -// e := "user_id not found in refresh token" -// errStr = &e -// } -// } else { -// e := "invalid refresh token claims" -// errStr = &e -// } + token, err := jwt.ParseWithClaims(refreshToken, &auth.TokenClaims{}, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method") + } + return []byte(s.JwtPrivateKey), nil + }) + if err != nil || !token.Valid { + log.Print("invalid refresh token") + return auth.RefreshTokens401Response{}, nil + } -// return auth.PostAuthVerifyToken200JSONResponse{ -// Valid: &valid, -// UserId: userID, -// Error: errStr, -// }, nil -// } + claims, ok := token.Claims.(*auth.TokenClaims) + if !ok || claims.UserID == "" { + log.Print("invalid refresh token claims") + return auth.RefreshTokens401Response{}, nil + } + if claims.Type != "refresh" { + log.Errorf("token is not a refresh token") + return auth.RefreshTokens401Response{}, nil + } + + accessToken, refreshToken, csrfToken, err := s.generateTokens(claims.UserID) + if err != nil { + log.Errorf("failed to generate tokens for user %s: %v", claims.UserID, err) + return auth.RefreshTokens500Response{}, nil + } + + // TODO: check cookie settings carefully + ginCtx.SetSameSite(http.SameSiteStrictMode) + ginCtx.SetCookie("access_token", accessToken, 900, "/api", "", false, true) + ginCtx.SetCookie("refresh_token", refreshToken, 1209600, "/auth", "", false, true) + ginCtx.SetCookie("xsrf_token", csrfToken, 1209600, "/", "", false, false) + + return auth.RefreshTokens200Response{}, nil +} func ExtractBearerToken(header string) (string, error) { const prefix = "Bearer " @@ -277,3 +291,9 @@ func ExtractBearerToken(header string) (string, error) { } return header[len(prefix):], nil } + +func generateJTI() string { + b := make([]byte, 16) + _, _ = rand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} diff --git a/modules/auth/main.go b/modules/auth/main.go index 7305b7d..bbeb014 100644 --- a/modules/auth/main.go +++ b/modules/auth/main.go @@ -46,7 +46,7 @@ func main() { log.Info("allow origins:", AppConfig.ServiceAddress) r.Use(cors.New(cors.Config{ - AllowOrigins: []string{"*"}, + AllowOrigins: []string{AppConfig.ServiceAddress}, AllowMethods: []string{"GET", "POST", "PUT", "DELETE"}, AllowHeaders: []string{"Origin", "Content-Type", "Accept"}, ExposeHeaders: []string{"Content-Length"}, diff --git a/modules/backend/middlewares/access.go b/modules/backend/middlewares/access.go index 73200e8..8e787f8 100644 --- a/modules/backend/middlewares/access.go +++ b/modules/backend/middlewares/access.go @@ -3,8 +3,11 @@ package middleware import ( "context" "errors" + "fmt" "net/http" + "nyanimedb/auth" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" ) @@ -37,12 +40,18 @@ func JWTAuthMiddleware(secret string) gin.HandlerFunc { } // 2. Парсим токен с MapClaims - token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) { - if t.Method != jwt.SigningMethodHS256 { - return nil, errors.New("unexpected signing method: " + t.Method.Alg()) + token, err := jwt.ParseWithClaims(tokenStr, &auth.TokenClaims{}, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method") } - return []byte(secret), nil // ← конвертируем string → []byte + return []byte(secret), nil }) + // token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) { + // if t.Method != jwt.SigningMethodHS256 { + // return nil, errors.New("unexpected signing method: " + t.Method.Alg()) + // } + // return []byte(secret), nil // ← конвертируем string → []byte + // }) if err != nil { abortWithJSON(c, http.StatusUnauthorized, "invalid token: "+err.Error()) return @@ -55,20 +64,23 @@ func JWTAuthMiddleware(secret string) gin.HandlerFunc { } // 4. Извлекаем user_id из claims - claims, ok := token.Claims.(jwt.MapClaims) + claims, ok := token.Claims.(*auth.TokenClaims) if !ok { abortWithJSON(c, http.StatusUnauthorized, "invalid claims format") return } - userID, ok := claims["user_id"].(string) - if !ok || userID == "" { + if claims.UserID == "" { abortWithJSON(c, http.StatusUnauthorized, "user_id claim missing or invalid") return } + if claims.Type != "access" { + abortWithJSON(c, http.StatusUnauthorized, "token type is not access") + return + } // 5. Сохраняем в контексте - c.Set("user_id", userID) + c.Set("user_id", claims.UserID) // 6. Для oapi-codegen — кладём gin.Context в request context GinContextToContext(c)