另客网go项目公用的代码库

tcp.go 3.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package linker
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "net"
  7. "runtime"
  8. "time"
  9. "github.com/wpajqz/linker/utils/convert"
  10. )
  11. func (s *Server) handleTCPConnection(conn *net.TCPConn) error {
  12. var ctx Context = &ContextTcp{Conn: conn}
  13. if s.constructHandler != nil {
  14. s.constructHandler.Handle(ctx)
  15. }
  16. defer func() {
  17. if s.destructHandler != nil {
  18. s.destructHandler.Handle(ctx)
  19. }
  20. conn.Close()
  21. }()
  22. if s.config.ReadBufferSize > 0 {
  23. conn.SetReadBuffer(s.config.ReadBufferSize)
  24. }
  25. if s.config.WriteBufferSize > 0 {
  26. conn.SetWriteBuffer(s.config.WriteBufferSize)
  27. }
  28. var (
  29. bType = make([]byte, 4)
  30. bSequence = make([]byte, 8)
  31. bHeaderLength = make([]byte, 4)
  32. bBodyLength = make([]byte, 4)
  33. sequence int64
  34. headerLength uint32
  35. bodyLength uint32
  36. )
  37. for {
  38. if s.config.Timeout != 0 {
  39. conn.SetDeadline(time.Now().Add(s.config.Timeout))
  40. }
  41. if _, err := io.ReadFull(conn, bType); err != nil {
  42. return err
  43. }
  44. if _, err := io.ReadFull(conn, bSequence); err != nil {
  45. return err
  46. }
  47. if _, err := io.ReadFull(conn, bHeaderLength); err != nil {
  48. return err
  49. }
  50. if _, err := io.ReadFull(conn, bBodyLength); err != nil {
  51. return err
  52. }
  53. sequence = convert.BytesToInt64(bSequence)
  54. headerLength = convert.BytesToUint32(bHeaderLength)
  55. bodyLength = convert.BytesToUint32(bBodyLength)
  56. pacLen := headerLength + bodyLength + uint32(20)
  57. if pacLen > s.config.MaxPayload {
  58. _, file, line, _ := runtime.Caller(1)
  59. return SystemError{time.Now(), file, line, "packet larger than MaxPayload"}
  60. }
  61. header := make([]byte, headerLength)
  62. if _, err := io.ReadFull(conn, header); err != nil {
  63. return err
  64. }
  65. body := make([]byte, bodyLength)
  66. if _, err := io.ReadFull(conn, body); err != nil {
  67. return err
  68. }
  69. rp, err := NewPacket(convert.BytesToUint32(bType), sequence, header, body, s.config.PluginForPacketReceiver)
  70. if err != nil {
  71. return err
  72. }
  73. ctx = NewContextTcp(conn, rp.Operator, rp.Sequence, rp.Header, rp.Body, s.config)
  74. go s.handleTCPPacket(ctx, conn, rp)
  75. }
  76. }
  77. func (s *Server) handleTCPPacket(ctx Context, conn net.Conn, rp Packet) {
  78. defer func() {
  79. if r := recover(); r != nil {
  80. if s.errorHandler != nil {
  81. buf := make([]byte, 1<<12)
  82. n := runtime.Stack(buf, false)
  83. s.errorHandler(errors.New(string(buf[:n])))
  84. }
  85. }
  86. }()
  87. if rp.Operator == OPERATOR_HEARTBEAT && s.pingHandler != nil {
  88. s.pingHandler.Handle(ctx)
  89. ctx.Success(nil)
  90. }
  91. handler, ok := s.router.handlerContainer[rp.Operator]
  92. if !ok {
  93. ctx.Error(StatusInternalServerError, "server don't register your request.")
  94. }
  95. if rm, ok := s.router.routerMiddleware[rp.Operator]; ok {
  96. for _, v := range rm {
  97. ctx = v.Handle(ctx)
  98. }
  99. }
  100. for _, v := range s.router.middleware {
  101. ctx = v.Handle(ctx)
  102. if tm, ok := v.(TerminateMiddleware); ok {
  103. tm.Terminate(ctx)
  104. }
  105. }
  106. handler.Handle(ctx)
  107. ctx.Success(nil) // If it don't call the function of Success or Error, deal it by default
  108. }
  109. // 开始运行Tcp服务
  110. func (s *Server) RunTCP(name, address string) error {
  111. tcpAddr, err := net.ResolveTCPAddr(name, address)
  112. if err != nil {
  113. return err
  114. }
  115. listener, err := net.ListenTCP(name, tcpAddr)
  116. if err != nil {
  117. return err
  118. }
  119. defer listener.Close()
  120. fmt.Printf("tcp server running on %s\n", address)
  121. for {
  122. conn, err := listener.AcceptTCP()
  123. if err != nil {
  124. continue
  125. }
  126. go s.handleTCPConnection(conn)
  127. }
  128. }