feat: xsrf_token set

This commit is contained in:
nihonium 2025-12-04 06:29:20 +03:00
parent 4dd60f3b19
commit ef871833c5
Signed by: nihonium
GPG key ID: 0251623741027CFC
5 changed files with 117 additions and 20 deletions

View file

@ -62,6 +62,8 @@ services:
environment: environment:
LOG_LEVEL: ${LOG_LEVEL} LOG_LEVEL: ${LOG_LEVEL}
DATABASE_URL: ${DATABASE_URL} DATABASE_URL: ${DATABASE_URL}
SERVICE_ADDRESS: ${SERVICE_ADDRESS}
JWT_PRIVATE_KEY: ${JWT_PRIVATE_KEY}
ports: ports:
- "8082:8082" - "8082:8082"
depends_on: depends_on:

View file

@ -2,6 +2,8 @@ package handlers
import ( import (
"context" "context"
"crypto/rand"
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
auth "nyanimedb/auth" auth "nyanimedb/auth"
@ -15,15 +17,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var accessSecret = []byte("my_access_secret_key")
var refreshSecret = []byte("my_refresh_secret_key")
type Server struct { type Server struct {
db *sqlc.Queries db *sqlc.Queries
JwtPrivateKey string
} }
func NewServer(db *sqlc.Queries) Server { func NewServer(db *sqlc.Queries, JwtPrivatekey string) Server {
return Server{db: db} return Server{db: db, JwtPrivateKey: JwtPrivatekey}
} }
func parseInt64(s string) (int32, error) { func parseInt64(s string) (int32, error) {
@ -47,15 +47,15 @@ func CheckPassword(password, hash string) (bool, error) {
return argon2id.ComparePasswordAndHash(password, hash) return argon2id.ComparePasswordAndHash(password, hash)
} }
func generateTokens(userID string) (accessToken string, refreshToken string, err error) { func (s Server) generateTokens(userID string) (accessToken string, refreshToken string, csrfToken string, err error) {
accessClaims := jwt.MapClaims{ accessClaims := jwt.MapClaims{
"user_id": userID, "user_id": userID,
"exp": time.Now().Add(15 * time.Minute).Unix(), "exp": time.Now().Add(15 * time.Minute).Unix(),
} }
at := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims) at := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
accessToken, err = at.SignedString(accessSecret) accessToken, err = at.SignedString(s.JwtPrivateKey)
if err != nil { if err != nil {
return "", "", err return "", "", "", err
} }
refreshClaims := jwt.MapClaims{ refreshClaims := jwt.MapClaims{
@ -63,12 +63,19 @@ func generateTokens(userID string) (accessToken string, refreshToken string, err
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(), "exp": time.Now().Add(7 * 24 * time.Hour).Unix(),
} }
rt := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims) rt := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
refreshToken, err = rt.SignedString(refreshSecret) refreshToken, err = rt.SignedString(s.JwtPrivateKey)
if err != nil { if err != nil {
return "", "", err return "", "", "", err
} }
return accessToken, refreshToken, nil csrfBytes := make([]byte, 32)
_, err = rand.Read(csrfBytes)
if err != nil {
return "", "", "", err
}
csrfToken = base64.RawURLEncoding.EncodeToString(csrfBytes)
return accessToken, refreshToken, csrfToken, nil
} }
func (s Server) PostAuthSignUp(ctx context.Context, req auth.PostAuthSignUpRequestObject) (auth.PostAuthSignUpResponseObject, error) { func (s Server) PostAuthSignUp(ctx context.Context, req auth.PostAuthSignUpRequestObject) (auth.PostAuthSignUpResponseObject, error) {
@ -118,7 +125,7 @@ func (s Server) PostAuthSignIn(ctx context.Context, req auth.PostAuthSignInReque
}, nil }, nil
} }
accessToken, refreshToken, err := generateTokens(req.Body.Nickname) accessToken, refreshToken, csrfToken, err := s.generateTokens(req.Body.Nickname)
if err != nil { if err != nil {
log.Errorf("failed to generate tokens for user %s: %v", req.Body.Nickname, err) log.Errorf("failed to generate tokens for user %s: %v", req.Body.Nickname, err)
// TODO: return 500 // TODO: return 500
@ -126,8 +133,9 @@ func (s Server) PostAuthSignIn(ctx context.Context, req auth.PostAuthSignInReque
// TODO: check cookie settings carefully // TODO: check cookie settings carefully
ginCtx.SetSameSite(http.SameSiteStrictMode) ginCtx.SetSameSite(http.SameSiteStrictMode)
ginCtx.SetCookie("access_token", accessToken, 604800, "/auth", "", false, true) ginCtx.SetCookie("access_token", accessToken, 900, "/api", "", false, true)
ginCtx.SetCookie("refresh_token", refreshToken, 604800, "/api", "", false, true) ginCtx.SetCookie("refresh_token", refreshToken, 1209600, "/auth", "", false, true)
ginCtx.SetCookie("xsrf_token", csrfToken, 1209600, "/api", "", false, false)
result := auth.PostAuthSignIn200JSONResponse{ result := auth.PostAuthSignIn200JSONResponse{
UserId: user.ID, UserId: user.ID,

33
modules/auth/helpers.go Normal file
View file

@ -0,0 +1,33 @@
package main
import (
"fmt"
"reflect"
)
func setField(obj interface{}, name string, value interface{}) error {
v := reflect.ValueOf(obj)
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
return fmt.Errorf("expected pointer to a struct")
}
v = v.Elem()
field := v.FieldByName(name)
if !field.IsValid() {
return fmt.Errorf("no such field: %s", name)
}
if !field.CanSet() {
return fmt.Errorf("cannot set field: %s", name)
}
val := reflect.ValueOf(value)
if field.Type() != val.Type() {
return fmt.Errorf("provided value type (%s) doesn't match field type (%s)", val.Type(), field.Type())
}
field.Set(val)
return nil
}

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"reflect"
"time" "time"
auth "nyanimedb/auth" auth "nyanimedb/auth"
@ -13,12 +14,24 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/pelletier/go-toml/v2"
log "github.com/sirupsen/logrus"
) )
var AppConfig Config var AppConfig Config
func main() { func main() {
// TODO: env args if len(os.Args) != 2 {
AppConfig.Mode = "env"
} else {
AppConfig.Mode = "argv"
}
err := InitConfig()
if err != nil {
log.Fatalf("Failed to init config: %v\n", err)
}
r := gin.Default() r := gin.Default()
pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL"))
@ -29,10 +42,10 @@ func main() {
var queries *sqlc.Queries = sqlc.New(pool) var queries *sqlc.Queries = sqlc.New(pool)
server := handlers.NewServer(queries) server := handlers.NewServer(queries, AppConfig.JwtPrivateKey)
r.Use(cors.New(cors.Config{ r.Use(cors.New(cors.Config{
AllowOrigins: []string{"*"}, // allow all origins, change to specific domains in production AllowOrigins: []string{AppConfig.ServiceAddress},
AllowMethods: []string{"GET", "POST", "PUT", "DELETE"}, AllowMethods: []string{"GET", "POST", "PUT", "DELETE"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept"}, AllowHeaders: []string{"Origin", "Content-Type", "Accept"},
ExposeHeaders: []string{"Content-Length"}, ExposeHeaders: []string{"Content-Length"},
@ -47,3 +60,41 @@ func main() {
r.Run(":8082") r.Run(":8082")
} }
func InitConfig() error {
if AppConfig.Mode == "argv" {
content, err := os.ReadFile(os.Args[1])
if err != nil {
return err
}
toml.Unmarshal(content, &AppConfig)
fmt.Printf("%+v\n", AppConfig)
return nil
} else if AppConfig.Mode == "env" {
f := reflect.ValueOf(AppConfig)
for i := 0; i < f.NumField(); i++ {
field := f.Type().Field(i)
tag := field.Tag
env_var := tag.Get("env")
fmt.Printf("Field: %v.\nEnvironment variable: %v.\n", field.Name, env_var)
if env_var != "" {
env_value, exists := os.LookupEnv(env_var)
if !exists {
return fmt.Errorf("there is no env variable %s", env_var)
}
err := setField(&AppConfig, field.Name, env_value)
if err != nil {
return fmt.Errorf("failed to set config field %s: %v", field.Name, err)
}
}
}
return nil
} else {
return fmt.Errorf("incorrect config mode")
}
}

View file

@ -1,6 +1,9 @@
package main package main
type Config struct { type Config struct {
JwtPrivateKey string Mode string
LogLevel string `toml:"LogLevel" env:"LOG_LEVEL"` ServiceAddress string `toml:"ServiceAddress" env:"SERVICE_ADDRESS"`
DdUrl string `toml:"DbUrl" env:"DATABASE_URL"`
JwtPrivateKey string `toml:"JwtPrivateKey" env:"JWT_PRIVATE_KEY"`
LogLevel string `toml:"LogLevel" env:"LOG_LEVEL"`
} }