315 lines
9.6 KiB
Go
315 lines
9.6 KiB
Go
package handlers
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/http"
|
|
auth "nyanimedb/auth"
|
|
sqlc "nyanimedb/sql"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/alexedwards/argon2id"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Server implements the auth handlers in this file (sign-up, sign-in,
// impersonation, token refresh and logout).
type Server struct {
	// db is the sqlc-generated query set used for user and external-service lookups.
	db *sqlc.Queries
	// JwtPrivateKey is used as the HMAC secret for signing (HS256) and verifying
	// JWTs; despite the name it is a symmetric key, not an asymmetric private key.
	JwtPrivateKey string
}
|
|
|
|
func NewServer(db *sqlc.Queries, JwtPrivatekey string) Server {
|
|
return Server{db: db, JwtPrivateKey: JwtPrivatekey}
|
|
}
|
|
|
|
// parseInt64 parses s as a base-10 integer and returns it as an int32.
//
// Despite the name (kept for existing callers) the result is an int32. The
// string is therefore parsed with a 32-bit range: values that do not fit are
// rejected with a *strconv.NumError (strconv.ErrRange) instead of being
// silently truncated to an unrelated number. On any error the returned value
// is 0.
func parseInt64(s string) (int32, error) {
	i, err := strconv.ParseInt(s, 10, 32)
	if err != nil {
		return 0, err
	}
	return int32(i), nil
}
|
|
|
|
func HashPassword(password string) (string, error) {
|
|
params := &argon2id.Params{
|
|
Memory: 64 * 1024,
|
|
Iterations: 3,
|
|
Parallelism: 2,
|
|
SaltLength: 16,
|
|
KeyLength: 32,
|
|
}
|
|
|
|
return argon2id.CreateHash(password, params)
|
|
}
|
|
|
|
func CheckPassword(password, hash string) (bool, error) {
|
|
return argon2id.ComparePasswordAndHash(password, hash)
|
|
}
|
|
|
|
// generateImpersonationToken issues a 15-minute HS256 access token whose
// subject is userID and whose ImpID claim records impersonatedBy (callers in
// this file pass the ID of the external service doing the impersonation).
// No refresh token is issued for impersonated sessions.
func (s *Server) generateImpersonationToken(userID string, impersonatedBy string) (string, error) {
	issued := time.Now()
	registered := jwt.RegisteredClaims{
		Subject:   userID,
		IssuedAt:  jwt.NewNumericDate(issued),
		ExpiresAt: jwt.NewNumericDate(issued.Add(15 * time.Minute)),
		ID:        generateJTI(),
	}

	signed := jwt.NewWithClaims(jwt.SigningMethodHS256, auth.TokenClaims{
		ImpID:            &impersonatedBy,
		Type:             "access",
		RegisteredClaims: registered,
	})
	return signed.SignedString([]byte(s.JwtPrivateKey))
}
|
|
|
|
// generateTokens issues a fresh credential set for userID:
//   - accessToken:  HS256 JWT, Type "access", valid for 15 minutes;
//   - refreshToken: HS256 JWT, Type "refresh", valid for 7 days;
//   - csrfToken:    32 bytes from crypto/rand, unpadded base64url-encoded.
//
// Both JWTs share the same issue time and each gets its own random JTI.
// On any error all three strings are returned empty.
func (s *Server) generateTokens(userID string) (accessToken string, refreshToken string, csrfToken string, err error) {
	now := time.Now()
	secret := []byte(s.JwtPrivateKey)

	// sign builds and signs one JWT of the given type and lifetime.
	sign := func(tokenType string, ttl time.Duration) (string, error) {
		claims := auth.TokenClaims{
			Type: tokenType,
			RegisteredClaims: jwt.RegisteredClaims{
				Subject:   userID,
				IssuedAt:  jwt.NewNumericDate(now),
				ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
				ID:        generateJTI(),
			},
		}
		return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secret)
	}

	access, signErr := sign("access", 15*time.Minute)
	if signErr != nil {
		return "", "", "", signErr
	}
	refresh, signErr := sign("refresh", 7*24*time.Hour)
	if signErr != nil {
		return "", "", "", signErr
	}

	// CSRF value; callers in this file store it in the non-HttpOnly
	// xsrf_token cookie so client-side code can read it.
	nonce := make([]byte, 32)
	if _, readErr := rand.Read(nonce); readErr != nil {
		return "", "", "", readErr
	}

	return access, refresh, base64.RawURLEncoding.EncodeToString(nonce), nil
}
|
|
|
|
// PostSignUp registers a new user from the nickname and password in the
// request body and returns the new user's ID.
//
// The password is stored only as an argon2id hash (HashPassword). If hashing
// or inserting fails, the handler stops and returns a non-nil error (the
// generated strict-server wrapper presumably renders that as a 500 — confirm);
// it never creates a user with an empty hash or reports success with a zero ID.
func (s Server) PostSignUp(ctx context.Context, req auth.PostSignUpRequestObject) (auth.PostSignUpResponseObject, error) {
	passhash, err := HashPassword(req.Body.Pass)
	if err != nil {
		log.Errorf("failed to hash password: %v", err)
		// Must not fall through: inserting with an empty passhash would
		// create an account with an unusable credential.
		return nil, fmt.Errorf("hashing password: %w", err)
	}

	// Use the request context so the insert is tied to the request lifetime.
	userID, err := s.db.CreateNewUser(ctx, sqlc.CreateNewUserParams{
		Passhash: passhash,
		Nickname: req.Body.Nickname,
	})
	if err != nil {
		log.Errorf("failed to create user %s: %v", req.Body.Nickname, err)
		// TODO: map a duplicate-nickname constraint violation to a 4xx once
		// the driver's error type is available in this package.
		return nil, fmt.Errorf("creating user %q: %w", req.Body.Nickname, err)
	}

	return auth.PostSignUp200JSONResponse{
		UserId: userID,
	}, nil
}
|
|
|
|
func (s Server) PostSignIn(ctx context.Context, req auth.PostSignInRequestObject) (auth.PostSignInResponseObject, error) {
|
|
ginCtx, ok := ctx.Value(gin.ContextKey).(*gin.Context)
|
|
if !ok {
|
|
log.Print("failed to get gin context")
|
|
// TODO: change to 500
|
|
return auth.PostSignIn200JSONResponse{}, fmt.Errorf("failed to get gin.Context from context.Context")
|
|
}
|
|
|
|
user, err := s.db.GetUserByNickname(context.Background(), req.Body.Nickname)
|
|
if err != nil {
|
|
log.Errorf("failed to get user by nickname %s: %v", req.Body.Nickname, err)
|
|
// TODO: return 400/500
|
|
}
|
|
|
|
ok, err = CheckPassword(req.Body.Pass, user.Passhash)
|
|
if err != nil {
|
|
log.Errorf("failed to check password for user %s: %v", req.Body.Nickname, err)
|
|
// TODO: return 500
|
|
}
|
|
if !ok {
|
|
return auth.PostSignIn401Response{}, nil
|
|
}
|
|
|
|
accessToken, refreshToken, csrfToken, err := s.generateTokens(fmt.Sprintf("%d", user.ID))
|
|
if err != nil {
|
|
log.Errorf("failed to generate tokens for user %s: %v", req.Body.Nickname, err)
|
|
// TODO: return 500
|
|
}
|
|
|
|
// TODO: check cookie settings carefully
|
|
ginCtx.SetSameSite(http.SameSiteStrictMode)
|
|
ginCtx.SetCookie("access_token", accessToken, 900, "/api", "", false, true)
|
|
ginCtx.SetCookie("refresh_token", refreshToken, 1209600, "/auth", "", false, true)
|
|
ginCtx.SetCookie("xsrf_token", csrfToken, 1209600, "/", "", false, false)
|
|
|
|
result := auth.PostSignIn200JSONResponse{
|
|
UserId: user.ID,
|
|
UserName: user.Nickname,
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// GetImpersonationToken lets an external service, authenticated by the bearer
// token in the Authorization header, obtain a short-lived access token for one
// of its users.
//
// The target user is identified by ExternalId (resolved within the calling
// service), by UserId, or by both, in which case they must refer to the same
// user. Supplying neither is rejected rather than minting a token with an
// empty subject.
//
// Authentication and lookup failures return the 401 response with a nil
// error. NOTE(review): the generated strict-server wrapper presumably discards
// the response object whenever a non-nil error is returned — confirm.
func (s Server) GetImpersonationToken(ctx context.Context, req auth.GetImpersonationTokenRequestObject) (auth.GetImpersonationTokenResponseObject, error) {
	ginCtx, ok := ctx.Value(gin.ContextKey).(*gin.Context)
	if !ok {
		log.Print("failed to get gin context")
		return nil, fmt.Errorf("failed to get gin.Context from context.Context")
	}

	// The bearer token is a service credential: never write it to the logs.
	token, err := ExtractBearerToken(ginCtx.Request.Header.Get("Authorization"))
	if err != nil {
		log.Errorf("failed to extract bearer token: %v", err)
		return auth.GetImpersonationToken401Response{}, nil
	}

	extService, err := s.db.GetExternalServiceByToken(ctx, &token)
	if err != nil {
		// TODO: distinguish "unknown token" (401) from database errors (500).
		log.Errorf("failed to get external service by token: %v", err)
		return auth.GetImpersonationToken401Response{}, nil
	}

	var userID string

	if req.Body.ExternalId != nil {
		user, err := s.db.GetUserByExternalServiceId(ctx, sqlc.GetUserByExternalServiceIdParams{
			ExternalID: fmt.Sprintf("%d", *req.Body.ExternalId),
			ServiceID:  extService.ID,
		})
		if err != nil {
			// TODO: distinguish "no such user" from database errors (500).
			log.Errorf("failed to get user by external user id: %v", err)
			return auth.GetImpersonationToken401Response{}, nil
		}
		userID = fmt.Sprintf("%d", user.ID)
	}

	if req.Body.UserId != nil {
		// TODO: check user existence
		requested := fmt.Sprintf("%d", *req.Body.UserId)
		if userID != "" && userID != requested {
			log.Error("user_id and external_id do not match")
			// TODO: return a more specific 4xx
			return auth.GetImpersonationToken401Response{}, nil
		}
		userID = requested
	}

	if userID == "" {
		log.Error("neither user_id nor external_id provided")
		// TODO: return 400
		return auth.GetImpersonationToken401Response{}, nil
	}

	accessToken, err := s.generateImpersonationToken(userID, fmt.Sprintf("%d", extService.ID))
	if err != nil {
		log.Errorf("failed to generate impersonation token: %v", err)
		return nil, fmt.Errorf("generating impersonation token: %w", err)
	}

	return auth.GetImpersonationToken200JSONResponse{AccessToken: accessToken}, nil
}
|
|
|
|
func (s Server) RefreshTokens(ctx context.Context, req auth.RefreshTokensRequestObject) (auth.RefreshTokensResponseObject, error) {
|
|
ginCtx, ok := ctx.Value(gin.ContextKey).(*gin.Context)
|
|
if !ok {
|
|
log.Print("failed to get gin context")
|
|
return auth.RefreshTokens500Response{}, fmt.Errorf("failed to get gin.Context from context.Context")
|
|
}
|
|
|
|
rtCookie, err := ginCtx.Request.Cookie("refresh_token")
|
|
if err != nil {
|
|
log.Print("failed to get refresh_token cookie")
|
|
return auth.RefreshTokens400Response{}, fmt.Errorf("failed to get refresh_token cookie")
|
|
}
|
|
|
|
refreshToken := rtCookie.Value
|
|
|
|
token, err := jwt.ParseWithClaims(refreshToken, &auth.TokenClaims{}, func(t *jwt.Token) (interface{}, error) {
|
|
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method")
|
|
}
|
|
return []byte(s.JwtPrivateKey), nil
|
|
})
|
|
if err != nil || !token.Valid {
|
|
log.Print("invalid refresh token")
|
|
return auth.RefreshTokens401Response{}, nil
|
|
}
|
|
|
|
claims, ok := token.Claims.(*auth.TokenClaims)
|
|
if !ok || claims.Subject == "" {
|
|
log.Print("invalid refresh token claims")
|
|
return auth.RefreshTokens401Response{}, nil
|
|
}
|
|
if claims.Type != "refresh" {
|
|
log.Errorf("token is not a refresh token")
|
|
return auth.RefreshTokens401Response{}, nil
|
|
}
|
|
|
|
accessToken, refreshToken, csrfToken, err := s.generateTokens(claims.Subject)
|
|
if err != nil {
|
|
log.Errorf("failed to generate tokens for user %s: %v", claims.Subject, err)
|
|
return auth.RefreshTokens500Response{}, nil
|
|
}
|
|
|
|
// TODO: check cookie settings carefully
|
|
ginCtx.SetSameSite(http.SameSiteStrictMode)
|
|
ginCtx.SetCookie("access_token", accessToken, 900, "/api", "", false, true)
|
|
ginCtx.SetCookie("refresh_token", refreshToken, 1209600, "/auth", "", false, true)
|
|
ginCtx.SetCookie("xsrf_token", csrfToken, 1209600, "/", "", false, false)
|
|
|
|
return auth.RefreshTokens200Response{}, nil
|
|
}
|
|
|
|
// Logout clears the auth cookies (access_token, refresh_token, xsrf_token) by
// re-sending each with an empty value and a negative MaxAge, which net/http
// serializes as "Max-Age=0" (delete now).
//
// Name, path, Secure and HttpOnly match the values used when the sign-in and
// refresh handlers in this file set these cookies (Secure=false). Secure must
// not be true here while the cookies are set without it: browsers ignore
// Set-Cookie headers carrying Secure from plain-HTTP origins, so the cookies
// would never be cleared there.
func (s Server) Logout(ctx context.Context, req auth.LogoutRequestObject) (auth.LogoutResponseObject, error) {
	// TODO: get current tokens and add them to block list
	ginCtx, ok := ctx.Value(gin.ContextKey).(*gin.Context)
	if !ok {
		log.Print("failed to get gin context")
		return auth.Logout500Response{}, fmt.Errorf("failed to get gin.Context from context.Context")
	}

	ginCtx.SetCookie("access_token", "", -1, "/api", "", false, true)
	ginCtx.SetCookie("refresh_token", "", -1, "/auth", "", false, true)
	ginCtx.SetCookie("xsrf_token", "", -1, "/", "", false, false)

	return auth.Logout200Response{}, nil
}
|
|
|
|
// ExtractBearerToken returns the credential that follows the literal
// "Bearer " prefix of an Authorization header value. The prefix match is
// case-sensitive and the remainder is returned verbatim (not trimmed). An
// error is returned when the prefix is missing or nothing follows it.
func ExtractBearerToken(header string) (string, error) {
	const scheme = "Bearer "
	n := len(scheme)

	if len(header) > n && header[:n] == scheme {
		return header[n:], nil
	}
	return "", fmt.Errorf("invalid bearer token format")
}
|
|
|
|
// generateJTI returns a fresh JWT ID ("jti" claim): 16 bytes from crypto/rand,
// unpadded base64url-encoded (22 characters).
//
// It panics if the system random source fails. Ignoring that error would hand
// out all-zero (or partially filled), colliding IDs. Since Go 1.24
// crypto/rand.Read itself crashes the program on failure, so the panic only
// matters on older toolchains.
func generateJTI() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("generateJTI: crypto/rand failed: %v", err))
	}
	return base64.RawURLEncoding.EncodeToString(b)
}
|