另客网go项目公用的代码库

websocket.go 3.6KB

  1. package linker
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "runtime"
  7. "time"
  8. "errors"
  9. "sync"
  10. "github.com/gorilla/websocket"
  11. "github.com/wpajqz/linker/utils/convert"
  12. )
  13. func (s *Server) handleWebSocketConnection(conn *websocket.Conn) error {
  14. wsn := &webSocketConn{mutex: sync.Mutex{}, conn: conn}
  15. var ctx Context = &ContextWebsocket{Conn: wsn}
  16. if s.constructHandler != nil {
  17. s.constructHandler.Handle(ctx)
  18. }
  19. defer func() {
  20. if s.destructHandler != nil {
  21. s.destructHandler.Handle(ctx)
  22. }
  23. conn.Close()
  24. }()
  25. var (
  26. bType = make([]byte, 4)
  27. bSequence = make([]byte, 8)
  28. bHeaderLength = make([]byte, 4)
  29. bBodyLength = make([]byte, 4)
  30. sequence int64
  31. headerLength uint32
  32. bodyLength uint32
  33. )
  34. conn.SetReadLimit(MaxPayload)
  35. for {
  36. if s.config.Timeout != 0 {
  37. conn.SetReadDeadline(time.Now().Add(s.config.Timeout))
  38. }
  39. _, r, err := conn.NextReader()
  40. if err != nil {
  41. return err
  42. }
  43. if _, err := io.ReadFull(r, bType); err != nil {
  44. return err
  45. }
  46. if _, err := io.ReadFull(r, bSequence); err != nil {
  47. return err
  48. }
  49. if _, err := io.ReadFull(r, bHeaderLength); err != nil {
  50. return err
  51. }
  52. if _, err := io.ReadFull(r, bBodyLength); err != nil {
  53. return err
  54. }
  55. sequence = convert.BytesToInt64(bSequence)
  56. headerLength = convert.BytesToUint32(bHeaderLength)
  57. bodyLength = convert.BytesToUint32(bBodyLength)
  58. pacLen := headerLength + bodyLength + uint32(20)
  59. if pacLen > s.config.MaxPayload {
  60. _, file, line, _ := runtime.Caller(1)
  61. return SystemError{time.Now(), file, line, "packet larger than MaxPayload"}
  62. }
  63. header := make([]byte, headerLength)
  64. if _, err := io.ReadFull(r, header); err != nil {
  65. return err
  66. }
  67. body := make([]byte, bodyLength)
  68. if _, err := io.ReadFull(r, body); err != nil {
  69. return err
  70. }
  71. rp, err := NewPacket(convert.BytesToUint32(bType), sequence, header, body, s.config.PluginForPacketReceiver)
  72. if err != nil {
  73. return err
  74. }
  75. ctx = NewContextWebsocket(wsn, rp.Operator, rp.Sequence, rp.Header, rp.Body, s.config)
  76. go s.handleWebSocketPacket(ctx, conn, rp)
  77. }
  78. }
  79. func (s *Server) handleWebSocketPacket(ctx Context, conn *websocket.Conn, rp Packet) {
  80. defer func() {
  81. if r := recover(); r != nil {
  82. if s.errorHandler != nil {
  83. buf := make([]byte, 1<<12)
  84. n := runtime.Stack(buf, false)
  85. s.errorHandler(errors.New(string(buf[:n])))
  86. }
  87. }
  88. }()
  89. if rp.Operator == OPERATOR_HEARTBEAT && s.pingHandler != nil {
  90. s.pingHandler.Handle(ctx)
  91. ctx.Success(nil)
  92. }
  93. handler, ok := s.router.handlerContainer[rp.Operator]
  94. if !ok {
  95. ctx.Error(StatusInternalServerError, "server don't register your request.")
  96. }
  97. if rm, ok := s.router.routerMiddleware[rp.Operator]; ok {
  98. for _, v := range rm {
  99. ctx = v.Handle(ctx)
  100. }
  101. }
  102. for _, v := range s.router.middleware {
  103. ctx = v.Handle(ctx)
  104. if tm, ok := v.(TerminateMiddleware); ok {
  105. tm.Terminate(ctx)
  106. }
  107. }
  108. handler.Handle(ctx)
  109. ctx.Success(nil) // If it don't call the function of Success or Error, deal it by default
  110. }
  111. // 开始运行webocket服务
  112. func (s *Server) RunWebSocket(address string) error {
  113. http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
  114. var upgrade = websocket.Upgrader{
  115. HandshakeTimeout: s.config.Timeout,
  116. ReadBufferSize: s.config.ReadBufferSize,
  117. WriteBufferSize: s.config.WriteBufferSize,
  118. EnableCompression: true,
  119. CheckOrigin: func(r *http.Request) bool {
  120. return true
  121. },
  122. }
  123. conn, err := upgrade.Upgrade(w, r, nil)
  124. if err != nil {
  125. w.Write([]byte(err.Error()))
  126. return
  127. }
  128. go s.handleWebSocketConnection(conn)
  129. })
  130. fmt.Printf("websocket server running on %s\n", address)
  131. return http.ListenAndServe(address, nil)
  132. }