瀏覽代碼

refactor auth

Paul 6 年之前
父節點
當前提交
baaafddd96
共有 5 個檔案被更改,包括 113 行新增52 行删除
  1. 16
    1
      Gopkg.lock
  2. 21
    28
      middleware/auth/auth.go
  3. 13
    23
      middleware/auth/optional_auth.go
  4. 14
    0
      middleware/auth/session.go
  5. 49
    0
      middleware/auth/session_redis.go

+ 16
- 1
Gopkg.lock 查看文件

@@ -30,6 +30,21 @@
30 30
   ]
31 31
   revision = "bf7803815b0baa22ff7a10457932882dfbf09925"
32 32
 
33
+[[projects]]
34
+  name = "github.com/go-redis/redis"
35
+  packages = [
36
+    ".",
37
+    "internal",
38
+    "internal/consistenthash",
39
+    "internal/hashtag",
40
+    "internal/pool",
41
+    "internal/proto",
42
+    "internal/singleflight",
43
+    "internal/util"
44
+  ]
45
+  revision = "83fb42932f6145ce52df09860384a4653d2d332a"
46
+  version = "v6.12.0"
47
+
33 48
 [[projects]]
34 49
   name = "github.com/golang/protobuf"
35 50
   packages = ["proto"]
@@ -102,6 +117,6 @@
102 117
 [solve-meta]
103 118
   analyzer-name = "dep"
104 119
   analyzer-version = 1
105
-  inputs-digest = "7fe5f8b83f3a0556f7574a9334d97c080f888f616754772851ed06f70be00a37"
120
+  inputs-digest = "f39a3af36dc118d9ce3e7647ae989cd30486ef9729d721fe106bbb7cfaf6e4cb"
106 121
   solver-name = "gps-cdcl"
107 122
   solver-version = 1

+ 21
- 28
middleware/auth/auth.go 查看文件

@@ -5,7 +5,6 @@ import (
5 5
 	"github.com/dgrijalva/jwt-go"
6 6
 	"github.com/gin-gonic/gin"
7 7
 	"net/http"
8
-	"time"
9 8
 )
10 9
 
11 10
 const (
@@ -15,29 +14,27 @@ const (
15 14
 	ctxRequestTokenExpired        = "expired"
16 15
 )
17 16
 
18
-func Auth(authKey string) gin.HandlerFunc {
17
+func Auth(authKey string, session Session) gin.HandlerFunc {
19 18
 	return func(ctx *gin.Context) {
20
-		var (
21
-			err error
22
-			tk  = ctx.Request.Header.Get(ctxRequestHeaderAuthorization)
23
-		)
19
+		var tokenFromCookie, tokenFromHeader string
24 20
 
25
-		if tk == "" {
26
-			tk, err = ctx.Cookie(ctxRequestCookieAuthorization)
27
-			if err != nil {
28
-				ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"msg": "auth failed"})
29
-				return
21
+		tokenFromCookie, err := ctx.Cookie(ctxRequestCookieAuthorization)
22
+		if err != nil {
23
+			if err == http.ErrNoCookie {
24
+				tokenFromHeader = ctx.Request.Header.Get(ctxRequestHeaderAuthorization)
30 25
 			}
26
+		}
31 27
 
32
-			tk = "Bearer " + tk
28
+		if tokenFromHeader == "" {
29
+			tokenFromHeader = "Bearer " + tokenFromCookie
33 30
 		}
34 31
 
35
-		if len(tk) < 8 {
32
+		if len(tokenFromHeader) < 8 {
36 33
 			ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"msg": "auth failed"})
37 34
 			return
38 35
 		}
39 36
 
40
-		token, err := jwt.Parse(tk[7:], func(token *jwt.Token) (interface{}, error) {
37
+		token, err := jwt.Parse(tokenFromHeader[7:], func(token *jwt.Token) (interface{}, error) {
41 38
 			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
42 39
 				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
43 40
 			}
@@ -50,25 +47,21 @@ func Auth(authKey string) gin.HandlerFunc {
50 47
 			return
51 48
 		}
52 49
 
50
+		if !session.IsExistsJwtToken(token.Signature) {
51
+			ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"msg": "auth failed, token expired by server"})
52
+			return
53
+		}
54
+
53 55
 		if mapClaims, ok := token.Claims.(jwt.MapClaims); ok {
54 56
 			if expired, ok := mapClaims[ctxRequestTokenExpired].(float64); ok {
55
-				switch true {
56
-				case expired > 0:
57
-					if int64(expired) < time.Now().Unix() {
58
-						ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"msg": "auth failed, token timeout"})
59
-						return
60
-					}
57
+				if expired == 0 && tokenFromCookie == "" {
58
+					if session.DeleteJwtToken(token.Raw) {
59
+						ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"msg": "auth failed, token expired"})
61 60
 
62
-					// todo check expired from server
63
-				case expired == 0:
64
-					// Only cookie is exists, check token expired. app expired by itself call logout when app exit
65
-					if _, err := ctx.Cookie(ctxRequestCookieAuthorization); err != nil {
66
-						ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"msg": "auth failed, token timeout"})
67
-						return
61
+					} else {
62
+						ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"msg": "auth failed, delete server token failed"})
68 63
 					}
69 64
 
70
-				default:
71
-					ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"msg": "auth failed, token timeout"})
72 65
 					return
73 66
 				}
74 67
 

+ 13
- 23
middleware/auth/optional_auth.go 查看文件

@@ -2,33 +2,32 @@ package auth
2 2
 
3 3
 import (
4 4
 	"fmt"
5
-	"time"
6
-
7 5
 	"github.com/dgrijalva/jwt-go"
8 6
 	"github.com/gin-gonic/gin"
7
+	"net/http"
9 8
 )
10 9
 
11 10
 func OptionalAuth(authKey string) gin.HandlerFunc {
12 11
 	return func(ctx *gin.Context) {
13
-		var (
14
-			err error
15
-			tk  = ctx.Request.Header.Get(ctxRequestHeaderAuthorization)
16
-		)
17
-
18
-		if tk == "" {
19
-			tk, err = ctx.Cookie(ctxRequestCookieAuthorization)
20
-			if err != nil {
21
-				return
12
+		var tokenFromCookie, tokenFromHeader string
13
+
14
+		tokenFromCookie, err := ctx.Cookie(ctxRequestCookieAuthorization)
15
+		if err != nil {
16
+			if err == http.ErrNoCookie {
17
+				tokenFromHeader = ctx.Request.Header.Get(ctxRequestHeaderAuthorization)
22 18
 			}
19
+		}
23 20
 
24
-			tk = "Bearer " + tk
21
+		if tokenFromHeader == "" {
22
+			tokenFromHeader = "Bearer " + tokenFromCookie
25 23
 		}
26 24
 
27
-		if len(tk) < 8 {
25
+		if len(tokenFromHeader) < 8 {
26
+			ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"msg": "auth failed"})
28 27
 			return
29 28
 		}
30 29
 
31
-		token, err := jwt.Parse(tk[7:], func(token *jwt.Token) (interface{}, error) {
30
+		token, err := jwt.Parse(tokenFromHeader[7:], func(token *jwt.Token) (interface{}, error) {
32 31
 			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
33 32
 				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
34 33
 			}
@@ -41,15 +40,6 @@ func OptionalAuth(authKey string) gin.HandlerFunc {
41 40
 		}
42 41
 
43 42
 		if mapClaims, ok := token.Claims.(jwt.MapClaims); ok {
44
-			if expired, ok := mapClaims[ctxRequestTokenExpired].(float64); ok {
45
-				if int64(expired) < time.Now().Unix() {
46
-					// Only cookie is blank value, check token expired
47
-					if _, err := ctx.Cookie(ctxRequestCookieAuthorization); err != nil {
48
-						return
49
-					}
50
-				}
51
-			}
52
-
53 43
 			if uid, ok := mapClaims[CtxRequestHeaderUserId].(float64); ok {
54 44
 				ctx.Set(CtxRequestHeaderUserId, int64(uid))
55 45
 			}

+ 14
- 0
middleware/auth/session.go 查看文件

@@ -0,0 +1,14 @@
1
+package auth
2
+
3
+import "time"
4
+
5
+type Session interface {
6
+	// StoreJwtToken store jwt token is redis, make it expired for feature
7
+	StoreJwtToken(key string, value string, timeout time.Duration) error
8
+
9
+	// IsExistsJwtToken judge whether token is invalid
10
+	IsExistsJwtToken(key string) bool
11
+
12
+	// DeleteJwtToken delete jwt token
13
+	DeleteJwtToken(key string) bool
14
+}

+ 49
- 0
middleware/auth/session_redis.go 查看文件

@@ -0,0 +1,49 @@
1
+package auth
2
+
3
+import (
4
+	"github.com/go-redis/redis"
5
+	"sync"
6
+	"time"
7
+)
8
+
9
+type RedisSessionStore struct {
10
+	once   *sync.Once
11
+	client *redis.Client
12
+}
13
+
14
+func NewRedisSessionStore(address string, password string) *RedisSessionStore {
15
+	session := &RedisSessionStore{}
16
+	session.once.Do(func() {
17
+		session.client = redis.NewClient(&redis.Options{
18
+			Addr:     address,
19
+			Password: password,
20
+		})
21
+
22
+		if err := session.client.Ping().Err(); err != nil {
23
+			panic(err)
24
+		}
25
+	})
26
+
27
+	return session
28
+}
29
+
30
+func (rss *RedisSessionStore) StoreJwtToken(key string, value string, timeout time.Duration) error {
31
+	return rss.client.Set(key, value, timeout).Err()
32
+}
33
+
34
+func (rss *RedisSessionStore) IsExistsJwtToken(key string) bool {
35
+	if v := rss.client.Exists(key).Val(); v != 1 {
36
+		return false
37
+	}
38
+
39
+	return true
40
+}
41
+
42
+// DeleteJwtToken delete jwt token
43
+func (rss *RedisSessionStore) DeleteJwtToken(key string) bool {
44
+	if v := rss.client.Del(key).Val(); v != 1 {
45
+		return false
46
+	}
47
+
48
+	return true
49
+}