168 lines
3.3 KiB
Go
168 lines
3.3 KiB
Go
package utils
|
|
|
|
import (
	"errors"
	"fmt"
	"net"
	"sync"
	"time"

	"gitea.konchin.com/ytshih/inp2025/types"
	"github.com/spf13/viper"
	"github.com/vmihailenco/msgpack/v5"
	"go.uber.org/zap"
)
|
|
|
|
// UDP protocol parameters shared by the listener and senders in this file.
const (
	// BUFFER_SIZE is a byte-buffer size of 1 KiB used for UDP I/O.
	BUFFER_SIZE int = 1 << 10
	// MAGIC_NUMBER tags every UDPPayload of this protocol; datagrams that
	// decode without it are ignored by the listener.
	MAGIC_NUMBER int = 114514
	// LISTEN_TIMEOUT is how long Ping waits for replies after sending its
	// requests.
	LISTEN_TIMEOUT time.Duration = 200 * time.Millisecond
)
|
|
|
|
// UDPReqType identifies what kind of message a UDPPayload carries. Its
// numeric value is transmitted on the wire, so existing values must not be
// renumbered.
type UDPReqType int

const (
	// UDPReqTypeData carries application data in UDPPayload.Data.
	UDPReqTypeData UDPReqType = 0
	// UDPReqTypePingRequest asks the receiver to send a UDPReqTypePingReply
	// back to UDPPayload.Endpoint.
	UDPReqTypePingRequest UDPReqType = 1
	// UDPReqTypePingReply answers a ping request; UDPPayload.Endpoint holds
	// the replier's advertised endpoint.
	UDPReqTypePingReply UDPReqType = 2
)

// UDPPayload is the envelope encoded (via the msgpack tags below) into every
// datagram exchanged by this package.
type UDPPayload struct {
	// MagicNumber must equal MAGIC_NUMBER for the datagram to be accepted.
	MagicNumber int `msgpack:"magicNumber"`
	// Endpoint is the sender's advertised "host:port".
	Endpoint string `msgpack:"endpoint"`
	// Type selects how the receiver handles the datagram.
	Type UDPReqType `msgpack:"type"`

	// Data is the application payload; only meaningful for UDPReqTypeData.
	Data string `msgpack:"data"`
}
|
|
|
|
func ListenUDPData(
|
|
port int,
|
|
dataCh chan string,
|
|
) (string, types.ShutdownFunc, error) {
|
|
return ListenUDP(port, dataCh, nil)
|
|
}
|
|
|
|
func Ping(endpoints []string) ([]string, error) {
|
|
pingCh := make(chan string)
|
|
local, shutdown, err := ListenUDP(0, nil, pingCh)
|
|
if err != nil {
|
|
return []string{}, err
|
|
}
|
|
defer shutdown()
|
|
|
|
for _, endpoint := range endpoints {
|
|
SendRawPayload(endpoint, UDPPayload{
|
|
MagicNumber: MAGIC_NUMBER,
|
|
Endpoint: local,
|
|
Type: UDPReqTypePingRequest,
|
|
})
|
|
}
|
|
|
|
doneCh := make(chan struct{})
|
|
go func() {
|
|
time.Sleep(LISTEN_TIMEOUT)
|
|
doneCh <- struct{}{}
|
|
}()
|
|
ret := []string{}
|
|
for {
|
|
select {
|
|
case <-doneCh:
|
|
return ret, nil
|
|
case endpoint := <-pingCh:
|
|
ret = append(ret, endpoint)
|
|
}
|
|
}
|
|
}
|
|
|
|
func ListenUDP(
|
|
port int,
|
|
dataCh chan string,
|
|
pingCh chan string,
|
|
) (string, types.ShutdownFunc, error) {
|
|
conn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port})
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to listen udp, %w", err)
|
|
}
|
|
addr, _ := net.ResolveUDPAddr("udp4", conn.LocalAddr().String())
|
|
local := fmt.Sprintf("%s:%d", viper.GetString("host"), addr.Port)
|
|
|
|
go func() {
|
|
for {
|
|
buffer := make([]byte, BUFFER_SIZE)
|
|
|
|
n, _, err := conn.ReadFromUDP(buffer)
|
|
if err != nil {
|
|
zap.L().Error("fuck udp",
|
|
zap.Error(err))
|
|
if errors.Is(err, net.ErrClosed) {
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
|
|
var payload UDPPayload
|
|
err = msgpack.Unmarshal(buffer[:n], &payload)
|
|
if err == nil && payload.MagicNumber == MAGIC_NUMBER {
|
|
switch payload.Type {
|
|
case UDPReqTypeData:
|
|
if dataCh != nil {
|
|
dataCh <- payload.Data
|
|
}
|
|
case UDPReqTypePingRequest:
|
|
SendRawPayload(payload.Endpoint, UDPPayload{
|
|
MagicNumber: MAGIC_NUMBER,
|
|
Endpoint: local,
|
|
Type: UDPReqTypePingReply,
|
|
})
|
|
case UDPReqTypePingReply:
|
|
if pingCh != nil {
|
|
pingCh <- payload.Endpoint
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return local, func() { conn.Close() }, nil
|
|
}
|
|
|
|
func SendPayload(
|
|
local, remote string,
|
|
data any,
|
|
) error {
|
|
sdata, err := msgpack.Marshal(data)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal data, %w", err)
|
|
}
|
|
|
|
return SendRawPayload(remote, UDPPayload{
|
|
MagicNumber: MAGIC_NUMBER,
|
|
Endpoint: local,
|
|
Type: UDPReqTypeData,
|
|
Data: string(sdata),
|
|
})
|
|
}
|
|
|
|
func SendRawPayload(
|
|
endpoint string,
|
|
payload UDPPayload,
|
|
) error {
|
|
conn, err := net.Dial("udp4", endpoint)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to dial endpoint, %w", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
b, err := msgpack.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal payload, %w", err)
|
|
}
|
|
|
|
_, err = conn.Write(b)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to send payload, %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|