Fix: rotate access token and add test
This commit is contained in:
@@ -44,7 +44,7 @@ func (self *Handlers) PostLogin(
|
||||
}
|
||||
}
|
||||
|
||||
session, err := self.db.GetSession(ctx, input.Token)
|
||||
session, err := self.db.GetSessionByLoginToken(ctx, input.Token)
|
||||
if err != nil {
|
||||
return middlewares.HTTPError{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
|
||||
@@ -18,7 +18,7 @@ func NewBunDatabase(db *bun.DB) *BunDatabase {
|
||||
return &BunDatabase{db: db}
|
||||
}
|
||||
|
||||
func (self *BunDatabase) GetSession(
|
||||
func (self *BunDatabase) GetSessionByLoginToken(
|
||||
ctx context.Context,
|
||||
loginToken string,
|
||||
) (models.Session, error) {
|
||||
@@ -35,6 +35,23 @@ func (self *BunDatabase) GetSession(
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (self *BunDatabase) GetSessionByUserId(
|
||||
ctx context.Context,
|
||||
userId string,
|
||||
) (models.Session, error) {
|
||||
ret := models.Session{
|
||||
UserId: userId,
|
||||
}
|
||||
err := self.db.NewSelect().
|
||||
Model(&ret).
|
||||
Where("user_id = ?", userId).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
return models.Session{}, err
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (self *BunDatabase) UpdateRefreshToken(
|
||||
ctx context.Context,
|
||||
userId string,
|
||||
@@ -88,6 +105,7 @@ func (self *BunDatabase) UpsertLoginToken(
|
||||
session := models.Session{
|
||||
UserId: userId,
|
||||
LoginToken: token,
|
||||
IsValid: true,
|
||||
}
|
||||
_, err = self.db.NewInsert().
|
||||
Model(&session).
|
||||
|
||||
@@ -7,11 +7,16 @@ import (
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
GetSession(
|
||||
GetSessionByLoginToken(
|
||||
ctx context.Context,
|
||||
loginToken string,
|
||||
) (models.Session, error)
|
||||
|
||||
GetSessionByUserId(
|
||||
ctx context.Context,
|
||||
userId string,
|
||||
) (models.Session, error)
|
||||
|
||||
UpdateRefreshToken(
|
||||
ctx context.Context,
|
||||
userId string,
|
||||
|
||||
@@ -30,7 +30,8 @@ func refreshAccessToken(
|
||||
return "", types.ContextNotExistError
|
||||
}
|
||||
|
||||
session, err := db.GetSession(ctx, refreshTokenClaim.UserId)
|
||||
session, err := db.GetSessionByUserId(ctx,
|
||||
refreshTokenClaim.UserId)
|
||||
if err != nil {
|
||||
tracing.Logger.Ctx(ctx).
|
||||
Warn("session not exist", zap.Error(err))
|
||||
@@ -68,7 +69,7 @@ func (self *Handlers) CheckAccessToken(
|
||||
if err != nil {
|
||||
return HTTPError{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
Message: "access token refresh failed",
|
||||
Message: "failed to refresh access token",
|
||||
OriginError: err,
|
||||
}
|
||||
}
|
||||
@@ -86,7 +87,7 @@ func (self *Handlers) CheckAccessToken(
|
||||
if err != nil {
|
||||
return HTTPError{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
Message: "access token refresh failed",
|
||||
Message: "failed to refresh access token",
|
||||
OriginError: err,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,10 +6,12 @@ import (
|
||||
"time"
|
||||
|
||||
"gitea.konchin.com/go2025/backend/models"
|
||||
"gitea.konchin.com/go2025/backend/tracing"
|
||||
"gitea.konchin.com/go2025/backend/types"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/uptrace/bunrouter"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (self *Handlers) CheckRefreshToken(
|
||||
@@ -46,6 +48,10 @@ func (self *Handlers) CheckRefreshToken(
|
||||
}
|
||||
}
|
||||
|
||||
tracing.Logger.Ctx(ctx).
|
||||
Debug("where is my fucking UserId",
|
||||
zap.String("userId", claim.UserId))
|
||||
|
||||
// check time and refresh
|
||||
timeLeft := claim.ExpiresAt.Time.Sub(time.Now()) / time.Second
|
||||
if int64(timeLeft) < viper.GetInt64("refresh-token-timeout")/2 {
|
||||
|
||||
@@ -67,6 +67,7 @@ func Test_01_Login(t *testing.T) {
|
||||
if len(cookie.Value) == 0 {
|
||||
t.Fatal("empty refresh token")
|
||||
}
|
||||
client.SetCookie(cookie)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
15
tests/02_getAliases_test.go
Normal file
15
tests/02_getAliases_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_02_GetImages(t *testing.T) {
|
||||
resp, err := client.R().
|
||||
Get("http://localhost:8080/api/aliases")
|
||||
if err != nil || resp.StatusCode() != http.StatusOK {
|
||||
t.Logf("%+v", resp)
|
||||
t.Fatal("failed to fetch aliases")
|
||||
}
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_02_GetImages(t *testing.T) {
|
||||
}
|
||||
Reference in New Issue
Block a user