123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- package linker
-
- import (
- "fmt"
- "io"
- "net/http"
- "runtime"
- "time"
-
- "errors"
-
- "sync"
-
- "github.com/gorilla/websocket"
- "github.com/wpajqz/linker/utils/convert"
- )
-
- func (s *Server) handleWebSocketConnection(conn *websocket.Conn) error {
- wsn := &webSocketConn{mutex: sync.Mutex{}, conn: conn}
- var ctx Context = &ContextWebsocket{Conn: wsn}
-
- if s.constructHandler != nil {
- s.constructHandler.Handle(ctx)
- }
-
- defer func() {
- if s.destructHandler != nil {
- s.destructHandler.Handle(ctx)
- }
-
- conn.Close()
- }()
-
- var (
- bType = make([]byte, 4)
- bSequence = make([]byte, 8)
- bHeaderLength = make([]byte, 4)
- bBodyLength = make([]byte, 4)
- sequence int64
- headerLength uint32
- bodyLength uint32
- )
-
- conn.SetReadLimit(MaxPayload)
-
- for {
- if s.config.Timeout != 0 {
- conn.SetReadDeadline(time.Now().Add(s.config.Timeout))
- }
-
- _, r, err := conn.NextReader()
- if err != nil {
- return err
- }
-
- if _, err := io.ReadFull(r, bType); err != nil {
- return err
- }
-
- if _, err := io.ReadFull(r, bSequence); err != nil {
- return err
- }
-
- if _, err := io.ReadFull(r, bHeaderLength); err != nil {
- return err
- }
-
- if _, err := io.ReadFull(r, bBodyLength); err != nil {
- return err
- }
-
- sequence = convert.BytesToInt64(bSequence)
- headerLength = convert.BytesToUint32(bHeaderLength)
- bodyLength = convert.BytesToUint32(bBodyLength)
- pacLen := headerLength + bodyLength + uint32(20)
-
- if pacLen > s.config.MaxPayload {
- _, file, line, _ := runtime.Caller(1)
- return SystemError{time.Now(), file, line, "packet larger than MaxPayload"}
- }
-
- header := make([]byte, headerLength)
- if _, err := io.ReadFull(r, header); err != nil {
- return err
-
- }
-
- body := make([]byte, bodyLength)
- if _, err := io.ReadFull(r, body); err != nil {
- return err
- }
-
- rp, err := NewPacket(convert.BytesToUint32(bType), sequence, header, body, s.config.PluginForPacketReceiver)
-
- if err != nil {
- return err
- }
-
- ctx = NewContextWebsocket(wsn, rp.Operator, rp.Sequence, rp.Header, rp.Body, s.config)
- go s.handleWebSocketPacket(ctx, conn, rp)
- }
- }
-
- func (s *Server) handleWebSocketPacket(ctx Context, conn *websocket.Conn, rp Packet) {
- defer func() {
- if r := recover(); r != nil {
- if s.errorHandler != nil {
- buf := make([]byte, 1<<12)
- n := runtime.Stack(buf, false)
- s.errorHandler(errors.New(string(buf[:n])))
- }
- }
- }()
-
- if rp.Operator == OPERATOR_HEARTBEAT && s.pingHandler != nil {
- s.pingHandler.Handle(ctx)
- ctx.Success(nil)
- }
-
- handler, ok := s.router.handlerContainer[rp.Operator]
- if !ok {
- ctx.Error(StatusInternalServerError, "server don't register your request.")
- }
-
- if rm, ok := s.router.routerMiddleware[rp.Operator]; ok {
- for _, v := range rm {
- ctx = v.Handle(ctx)
- }
- }
-
- for _, v := range s.router.middleware {
- ctx = v.Handle(ctx)
- if tm, ok := v.(TerminateMiddleware); ok {
- tm.Terminate(ctx)
- }
- }
-
- handler.Handle(ctx)
- ctx.Success(nil) // If it don't call the function of Success or Error, deal it by default
- }
-
- // 开始运行webocket服务
- func (s *Server) RunWebSocket(address string) error {
- http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
- var upgrade = websocket.Upgrader{
- HandshakeTimeout: s.config.Timeout,
- ReadBufferSize: s.config.ReadBufferSize,
- WriteBufferSize: s.config.WriteBufferSize,
- EnableCompression: true,
- CheckOrigin: func(r *http.Request) bool {
- return true
- },
- }
-
- conn, err := upgrade.Upgrade(w, r, nil)
- if err != nil {
- w.Write([]byte(err.Error()))
- return
- }
-
- go s.handleWebSocketConnection(conn)
- })
-
- fmt.Printf("websocket server running on %s\n", address)
-
- return http.ListenAndServe(address, nil)
- }
|