package tcp import ( "encoding/json" "io" "net" "strings" "go.uber.org/zap" ) type Handler func(w ResponseWriter, req *Request) error type Middleware func(next Handler) Handler type Router struct { middlewares []Middleware routes map[Method]map[string]Handler listenAddr string doneCh chan error } func NewRouter() *Router { return &Router{ routes: make(map[Method]map[string]Handler), doneCh: make(chan error), } } type Group struct { pRouter *Router pGroup *Group path string middlewares []Middleware } func (self *Router) Group(path string) *Group { return &Group{ pRouter: self, pGroup: nil, path: path, } } func (self *Group) Group(path string) *Group { return &Group{ pRouter: nil, pGroup: self, path: path, } } func (self *Router) Use(middleware Middleware) *Router { self.middlewares = append(self.middlewares, middleware) return self } func (self *Group) Use(middleware Middleware) *Group { self.middlewares = append(self.middlewares, middleware) return self } func (self *Router) Register(method Method, route string, handler Handler) { _, ok := self.routes[method] if !ok { self.routes[method] = make(map[string]Handler) } for _, middleware := range self.middlewares { handler = middleware(handler) } self.routes[method][route] = handler } func (self *Group) Register(method Method, route string, handler Handler) { for _, middleware := range self.middlewares { handler = middleware(handler) } if self.pRouter != nil { self.pRouter.Register(method, self.path+route, handler) } if self.pGroup != nil { self.pGroup.Register(method, self.path+route, handler) } } func (self *Router) GET(route string, handler Handler) { self.Register(MethodGET, route, handler) } func (self *Group) GET(route string, handler Handler) { self.Register(MethodGET, route, handler) } func (self *Router) POST(route string, handler Handler) { self.Register(MethodPOST, route, handler) } func (self *Group) POST(route string, handler Handler) { self.Register(MethodPOST, route, handler) } func (self *Router) PUT(route string, handler Handler) { self.Register(MethodPUT, route, handler) } func (self *Group) PUT(route string, handler Handler) { self.Register(MethodPUT, route, handler) } func (self *Router) DELETE(route string, handler Handler) { self.Register(MethodDELETE, route, handler) } func (self *Group) DELETE(route string, handler Handler) { self.Register(MethodDELETE, route, handler) } func (self *Router) SOCKET(route string, handler Handler) { self.Register(MethodSOCKET, route, handler) } func (self *Group) SOCKET(route string, handler Handler) { self.Register(MethodSOCKET, route, handler) } func (self *Router) run(conn net.Conn, req *Request) { handler, ok := self.routes[req.Method][req.Route] if !ok { zap.L().Warn("route not exist", zap.String("route", req.Route)) return } w := NewResponseBuilder(conn) err := handler(w, req) if err != nil { zap.L().Error("failed to run handler", zap.Error(err)) return } res := w.Build(req) header, err := res.Header() if err != nil { zap.L().Error("failed to marshal header", zap.Error(err)) return } if err := sendFrame(conn, header); err != nil { zap.L().Error("failed to write header", zap.Error(err)) return } if err := sendFrame(conn, res.Body); err != nil { zap.L().Error("failed to write body", zap.Error(err)) return } } func (self *Router) listen(listener net.Listener) { defer listener.Close() for { conn, err := listener.Accept() if err != nil { self.doneCh <- err return } rawHeader, err := readFrame(conn) if err != nil { if err == io.EOF { self.doneCh <- nil return } self.doneCh <- err return } body, err := readFrame(conn) if err != nil { if err == io.EOF { self.doneCh <- nil return } self.doneCh <- err return } var header RequestHeader if err := json.Unmarshal(rawHeader, &header); err != nil { self.doneCh <- err return } req := NewRequest(conn, header, body) go self.run(conn, req) } } func (self *Router) WaitFor() error { return <-self.doneCh } func (self *Router) ListenAddr() string { conn, _ := net.Dial("udp", "8.8.8.8:80") addr := conn.LocalAddr().(*net.UDPAddr).IP.String() conn.Close() ss2 := strings.Split(self.listenAddr, ":") return addr + ":" + ss2[len(ss2)-1] } func (self *Router) Listen(addr string) error { listener, err := net.Listen("tcp", addr) if err != nil { return err } self.listenAddr = listener.Addr().String() go self.listen(listener) return nil }