123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- package export
-
- import (
- "context"
- "fmt"
- "io"
- "net"
- "strconv"
- "time"
-
- "github.com/wpajqz/linker"
- "github.com/wpajqz/linker/plugins"
- "github.com/wpajqz/linker/utils/convert"
- )
-
- // 处理客户端连接
- func (c *Client) handleConnection(conn net.Conn) (err error) {
- ctx, cancel := context.WithCancel(context.Background())
- defer func(cancel context.CancelFunc) { cancel() }(cancel)
-
- q := make(chan bool, 2)
- go func(conn net.Conn) {
- err = c.handleSendPackets(ctx, conn)
- if err != nil {
- q <- true
- }
- }(conn)
-
- go func(conn net.Conn) {
- err = c.handleReceivedPackets(conn)
- if err != nil {
- q <- true
- }
- }(conn)
-
- <-q
-
- return
- }
-
- // 对发送的数据包进行处理
- func (c *Client) handleSendPackets(ctx context.Context, conn net.Conn) error {
- for {
- select {
- case p := <-c.packet:
- _, err := conn.Write(p.Bytes())
- if err != nil {
- return err
- }
-
- if c.timeout != 0 {
- conn.SetWriteDeadline(time.Now().Add(c.timeout))
- }
- case <-ctx.Done():
- return nil
- }
- }
- }
-
- // 对接收到的数据包进行处理
- func (c *Client) handleReceivedPackets(conn net.Conn) error {
- var (
- bType = make([]byte, 4)
- bSequence = make([]byte, 8)
- bHeaderLength = make([]byte, 4)
- bBodyLength = make([]byte, 4)
- sequence int64
- headerLength uint32
- bodyLength uint32
- pacLen uint32
- )
-
- for {
- if c.timeout != 0 {
- conn.SetReadDeadline(time.Now().Add(c.timeout))
- }
-
- if n, err := io.ReadFull(conn, bType); err != nil && n != 4 {
- return err
- }
-
- if n, err := io.ReadFull(conn, bSequence); err != nil && n != 8 {
- return err
- }
-
- if n, err := io.ReadFull(conn, bHeaderLength); err != nil && n != 4 {
- return err
- }
-
- if n, err := io.ReadFull(conn, bBodyLength); err != nil && n != 4 {
- return err
- }
-
- nType := convert.BytesToUint32(bType)
- sequence = convert.BytesToInt64(bSequence)
- headerLength = convert.BytesToUint32(bHeaderLength)
- bodyLength = convert.BytesToUint32(bBodyLength)
-
- pacLen = headerLength + bodyLength + 20
- if pacLen > MaxPayload {
- return fmt.Errorf("the packet is big than %v", strconv.Itoa(MaxPayload))
- }
-
- header := make([]byte, headerLength)
- if n, err := io.ReadFull(conn, header); err != nil && n != int(headerLength) {
- return err
- }
-
- body := make([]byte, bodyLength)
- if n, err := io.ReadFull(conn, body); err != nil && n != int(bodyLength) {
- return err
- }
-
- receive, err := linker.NewPacket(nType, sequence, header, body, []linker.PacketPlugin{
- &plugins.Decryption{},
- })
- if err != nil {
- return err
- }
-
- c.response.Header = receive.Header
- c.response.Body = receive.Body
-
- operator := int64(nType) + sequence
- if handler, ok := c.handlerContainer.Load(operator); ok {
- if v, ok := handler.(Handler); ok {
- v.Handle(receive.Header, receive.Body)
- }
- }
- }
- }
|