Files
inp2025/tcp/router.go
2025-11-10 14:52:16 +08:00

227 lines
4.5 KiB
Go

package tcp
import (
"encoding/json"
"io"
"net"
"strings"
"go.uber.org/zap"
)
type (
	// Handler serves one decoded request, writing its reply through w.
	// Returning a non-nil error makes the router log the failure and skip
	// sending a response frame.
	Handler func(w ResponseWriter, req *Request) error

	// Middleware decorates a Handler, returning a new Handler that typically
	// does extra work before and/or after calling next.
	Middleware func(next Handler) Handler
)
// Router accepts TCP connections and dispatches framed requests to handlers
// looked up by (method, route).
type Router struct {
	// middlewares are applied to a handler when it is registered; middleware
	// added with Use afterwards does not affect routes already registered.
	middlewares []Middleware
	// routes maps method -> exact route string -> middleware-wrapped handler.
	routes map[Method]map[string]Handler
	// listenAddr is the listener's resolved address, set by Listen.
	listenAddr string
	// doneCh carries the value (nil or an error) that ended the listen loop;
	// it is read by WaitFor.
	doneCh chan error
}
// NewRouter returns an empty Router ready for route registration.
//
// doneCh has a one-slot buffer so the listen loop can always report why it
// ended and return. With an unbuffered channel and no WaitFor caller, that
// goroutine would block forever on the send, and its deferred listener.Close
// would never run.
func NewRouter() *Router {
	return &Router{
		routes: make(map[Method]map[string]Handler),
		doneCh: make(chan error, 1),
	}
}
// Group registers routes under a shared path prefix with its own middleware.
// The constructors set exactly one parent: Router.Group sets pRouter (a
// top-level group), and Group.Group sets pGroup (a nested group).
type Group struct {
	pRouter     *Router      // parent router for a top-level group, else nil
	pGroup      *Group       // parent group for a nested group, else nil
	path        string       // prefix prepended to every route registered here
	middlewares []Middleware // applied inside any parent's middleware
}
// Group creates a top-level route group. Routes registered on it are
// forwarded to rt with path prepended.
func (rt *Router) Group(path string) *Group {
	g := &Group{path: path}
	g.pRouter = rt
	return g
}
// Group creates a nested group. Its routes are forwarded to g with path
// prepended, so every enclosing group's prefix and middleware still apply.
func (g *Group) Group(path string) *Group {
	child := new(Group)
	child.pGroup = g
	child.path = path
	return child
}
// Use appends middleware to the router-wide chain and returns rt so calls can
// be chained. Only routes registered after this call are wrapped by it.
func (rt *Router) Use(middleware Middleware) *Router {
	chain := append(rt.middlewares, middleware)
	rt.middlewares = chain
	return rt
}
// Use appends middleware to this group's chain and returns g so calls can be
// chained. Only routes registered on g after this call are wrapped by it.
func (g *Group) Use(middleware Middleware) *Group {
	chain := append(g.middlewares, middleware)
	g.middlewares = chain
	return g
}
// Register wraps handler with the router's middleware and stores it under
// (method, route), replacing any handler already registered there.
//
// Middleware is applied in the order it was added with Use. The most recently
// added middleware therefore ends up outermost and runs first.
func (rt *Router) Register(method Method, route string, handler Handler) {
	table, exists := rt.routes[method]
	if !exists {
		table = make(map[string]Handler)
		rt.routes[method] = table
	}
	wrapped := handler
	for _, mw := range rt.middlewares {
		wrapped = mw(wrapped)
	}
	table[route] = wrapped
}
// Register wraps handler with this group's middleware, then forwards it to
// the parent (router or enclosing group) with the group's path prefixed. The
// parent then applies its own middleware outside the group's.
func (g *Group) Register(method Method, route string, handler Handler) {
	for _, mw := range g.middlewares {
		handler = mw(handler)
	}
	full := g.path + route
	if parent := g.pRouter; parent != nil {
		parent.Register(method, full, handler)
	}
	if parent := g.pGroup; parent != nil {
		parent.Register(method, full, handler)
	}
}
// GET registers handler for MethodGET requests on route.
func (rt *Router) GET(route string, handler Handler) { rt.Register(MethodGET, route, handler) }

// POST registers handler for MethodPOST requests on route.
func (rt *Router) POST(route string, handler Handler) { rt.Register(MethodPOST, route, handler) }

// PUT registers handler for MethodPUT requests on route.
func (rt *Router) PUT(route string, handler Handler) { rt.Register(MethodPUT, route, handler) }

// DELETE registers handler for MethodDELETE requests on route.
func (rt *Router) DELETE(route string, handler Handler) { rt.Register(MethodDELETE, route, handler) }

// SOCKET registers handler for MethodSOCKET requests on route.
func (rt *Router) SOCKET(route string, handler Handler) { rt.Register(MethodSOCKET, route, handler) }

// GET registers handler for MethodGET requests on the group-prefixed route.
func (g *Group) GET(route string, handler Handler) { g.Register(MethodGET, route, handler) }

// POST registers handler for MethodPOST requests on the group-prefixed route.
func (g *Group) POST(route string, handler Handler) { g.Register(MethodPOST, route, handler) }

// PUT registers handler for MethodPUT requests on the group-prefixed route.
func (g *Group) PUT(route string, handler Handler) { g.Register(MethodPUT, route, handler) }

// DELETE registers handler for MethodDELETE requests on the group-prefixed route.
func (g *Group) DELETE(route string, handler Handler) { g.Register(MethodDELETE, route, handler) }

// SOCKET registers handler for MethodSOCKET requests on the group-prefixed route.
func (g *Group) SOCKET(route string, handler Handler) { g.Register(MethodSOCKET, route, handler) }
// run dispatches req to the handler registered for (req.Method, req.Route).
// It then writes the built response back on conn as two frames: header first,
// then body.
//
// For an unknown route, conn is closed. No handler will ever take ownership of
// it, and leaving it open would leak the socket and keep the client waiting
// for a reply that never comes. On every other path conn is left open, because
// a handler may still be using it (presumably SOCKET routes).
// NOTE(review): confirm who closes conn after a normal request/response.
func (rt *Router) run(conn net.Conn, req *Request) {
	handler, ok := rt.routes[req.Method][req.Route]
	if !ok {
		zap.L().Warn("route not exist",
			zap.String("route", req.Route))
		_ = conn.Close() // best effort: the connection is being abandoned
		return
	}

	w := NewResponseBuilder(conn)
	if err := handler(w, req); err != nil {
		zap.L().Error("failed to run handler",
			zap.Error(err))
		return
	}

	res := w.Build(req)
	header, err := res.Header()
	if err != nil {
		zap.L().Error("failed to marshal header",
			zap.Error(err))
		return
	}
	if err := sendFrame(conn, header); err != nil {
		zap.L().Error("failed to write header",
			zap.Error(err))
		return
	}
	if err := sendFrame(conn, res.Body); err != nil {
		zap.L().Error("failed to write body",
			zap.Error(err))
		return
	}
}
// listen runs the accept loop for listener until Accept fails. It then
// reports that error on doneCh (for WaitFor) and closes the listener.
//
// Each accepted connection is handled on its own goroutine by serveConn. A
// slow client therefore cannot stall the accept loop. A client that
// disconnects early or sends a malformed header cannot shut the whole server
// down either; per-connection failures only close that one connection.
func (rt *Router) listen(listener net.Listener) {
	defer listener.Close()
	for {
		conn, err := listener.Accept()
		if err != nil {
			rt.doneCh <- err
			return
		}
		go rt.serveConn(conn)
	}
}

// serveConn reads one request from conn and dispatches it via run. A request
// is a frame holding a JSON-encoded RequestHeader, followed by a body frame.
// On a read or decode failure, conn is closed and the failure is logged.
// Otherwise ownership of conn passes to run.
//
// NOTE(review): no read deadline is set, so a client that connects and never
// sends keeps this goroutine parked; consider conn.SetReadDeadline.
func (rt *Router) serveConn(conn net.Conn) {
	drop := func(stage string, err error) {
		_ = conn.Close() // best effort: the connection is being abandoned anyway
		if err == io.EOF {
			// Peer hung up before sending a complete request; not a server fault.
			zap.L().Debug("connection closed by peer",
				zap.String("stage", stage))
			return
		}
		zap.L().Warn("failed to read request",
			zap.String("stage", stage), zap.Error(err))
	}

	rawHeader, err := readFrame(conn)
	if err != nil {
		drop("read header", err)
		return
	}
	body, err := readFrame(conn)
	if err != nil {
		drop("read body", err)
		return
	}
	var header RequestHeader
	if err := json.Unmarshal(rawHeader, &header); err != nil {
		drop("decode header", err)
		return
	}
	rt.run(conn, NewRequest(conn, header, body))
}
// WaitFor blocks until the listen loop started by Listen ends, then returns
// the value it reported (nil or an error). If Listen was never called, or it
// failed, nothing is ever sent and WaitFor blocks forever.
func (rt *Router) WaitFor() error {
	err := <-rt.doneCh
	return err
}
// ListenAddr returns a "host:port" that other machines can dial. It combines
// this machine's preferred outbound IP with the port the listener is bound to.
//
// The outbound IP comes from "connecting" a UDP socket to a public address.
// No packet is sent; the OS only selects a route and source address. Dial
// fails when there is no route at all (for example, an offline host). In that
// case the loopback address is used instead of dereferencing a nil conn.
func (rt *Router) ListenAddr() string {
	// The port is everything after the last ':', which also handles "[::]:8080".
	// Before Listen succeeds, listenAddr is "" and the port is empty.
	port := rt.listenAddr[strings.LastIndex(rt.listenAddr, ":")+1:]

	ip := "127.0.0.1"
	if conn, err := net.Dial("udp", "8.8.8.8:80"); err == nil {
		ip = conn.LocalAddr().(*net.UDPAddr).IP.String()
		_ = conn.Close() // UDP "connection" holds no remote state; nothing to report
	}
	// JoinHostPort brackets IPv6 hosts; for IPv4 the result is "ip:port".
	return net.JoinHostPort(ip, port)
}
// Listen binds a TCP listener on addr and starts the accept loop in the
// background. It returns immediately; use WaitFor to block until the loop
// ends.
func (rt *Router) Listen(addr string) error {
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	// Record the resolved address (for example, ":0" becomes a concrete port)
	// before the loop starts, so ListenAddr can report it.
	rt.listenAddr = ln.Addr().String()
	go rt.listen(ln)
	return nil
}