diff --git a/framework/core.go b/framework/core.go index 7336ddc..3f3d827 100644 --- a/framework/core.go +++ b/framework/core.go @@ -8,7 +8,8 @@ import ( // Core is the core struct of the framework type Core struct { - router map[string]*Trie + router map[string]*Trie + middlewares []ControllerHandler } // NewCore initialize the Core. @@ -28,40 +29,44 @@ func NewCore() *Core { } // Get is a simple get router -func (c *Core) Get(url string, handler ControllerHandler) { - upperUrl := strings.ToUpper(url) - if err := c.router["GET"].AddRouter(upperUrl, handler); err != nil { +func (c *Core) Get(url string, handlers ...ControllerHandler) { + allHandlers := append(c.middlewares, handlers...) + if err := c.router["GET"].AddRouter(url, allHandlers); err != nil { log.Println(err) } } // Post is a simple post router -func (c *Core) Post(url string, handler ControllerHandler) { - upperUrl := strings.ToUpper(url) - if err := c.router["POST"].AddRouter(upperUrl, handler); err != nil { +func (c *Core) Post(url string, handlers ...ControllerHandler) { + allHandlers := append(c.middlewares, handlers...) + if err := c.router["POST"].AddRouter(url, allHandlers); err != nil { log.Println(err) } } // Put is a simple put router -func (c *Core) Put(url string, handler ControllerHandler) { - upperUrl := strings.ToUpper(url) - if err := c.router["PUT"].AddRouter(upperUrl, handler); err != nil { +func (c *Core) Put(url string, handlers ...ControllerHandler) { + allHandlers := append(c.middlewares, handlers...) + if err := c.router["PUT"].AddRouter(url, allHandlers); err != nil { log.Println(err) } } // Delete is a simple delete router -func (c *Core) Delete(url string, handler ControllerHandler) { - upperUrl := strings.ToUpper(url) - if err := c.router["DELETE"].AddRouter(upperUrl, handler); err != nil { +func (c *Core) Delete(url string, handlers ...ControllerHandler) { + allHandlers := append(c.middlewares, handlers...) + if err := c.router["DELETE"].AddRouter(url, allHandlers); err != nil { log.Println(err) } } +// Use registers middlewares +func (c *Core) Use(middlewares ...ControllerHandler) { + c.middlewares = append(c.middlewares, middlewares...) +} + // FindRouteByRequest finds route using the request func (c *Core) FindRouteByRequest(r *http.Request) []ControllerHandler { - upperUri := strings.ToUpper(r.URL.Path) upperMethod := strings.ToUpper(r.Method) mapper, ok := c.router[upperMethod] @@ -70,7 +75,7 @@ func (c *Core) FindRouteByRequest(r *http.Request) []ControllerHandler { return nil } - controllers := mapper.FindRoute(upperUri) + controllers := mapper.FindRoute(r.URL.Path) if controllers == nil { log.Printf("URI %q is not recognized\n", r.URL.Path) return nil diff --git a/framework/group.go b/framework/group.go index f07c02f..9354450 100644 --- a/framework/group.go +++ b/framework/group.go @@ -2,16 +2,18 @@ package framework // IGroup prefix routes type IGroup interface { - Get(string, ControllerHandler) - Post(string, ControllerHandler) - Put(string, ControllerHandler) - Delete(string, ControllerHandler) + Get(string, ...ControllerHandler) + Post(string, ...ControllerHandler) + Put(string, ...ControllerHandler) + Delete(string, ...ControllerHandler) + Use(...ControllerHandler) } // Group is the implementation of IGroup interface type Group struct { - core *Core - prefix string + core *Core + prefix string + middlewares []ControllerHandler } // NewGroup create a new prefix group @@ -23,21 +25,30 @@ func NewGroup(core *Core, prefix string) *Group { } // Get is a simple get router of the group -func (g *Group) Get(url string, handler ControllerHandler) { - g.core.Get(g.prefix+url, handler) +func (g *Group) Get(url string, handlers ...ControllerHandler) { + allHandlers := append(g.middlewares, handlers...) + g.core.Get(g.prefix+url, allHandlers...) } // Post is a simple post router of the group -func (g *Group) Post(url string, handler ControllerHandler) { - g.core.Post(g.prefix+url, handler) +func (g *Group) Post(url string, handlers ...ControllerHandler) { + allHandlers := append(g.middlewares, handlers...) + g.core.Post(g.prefix+url, allHandlers...) } // Put is a simple put router of the group -func (g *Group) Put(url string, handler ControllerHandler) { - g.core.Put(g.prefix+url, handler) +func (g *Group) Put(url string, handlers ...ControllerHandler) { + allHandlers := append(g.middlewares, handlers...) + g.core.Put(g.prefix+url, allHandlers...) } // Delete is a simple delete router of the group -func (g *Group) Delete(url string, handler ControllerHandler) { - g.core.Delete(g.prefix+url, handler) +func (g *Group) Delete(url string, handlers ...ControllerHandler) { + allHandlers := append(g.middlewares, handlers...) + g.core.Delete(g.prefix+url, allHandlers...) +} + +// Use registers middlewares +func (g *Group) Use(middlewares ...ControllerHandler) { + g.middlewares = append(g.middlewares, middlewares...) } diff --git a/framework/middleware/test.go b/framework/middleware/test.go new file mode 100644 index 0000000..8a068ee --- /dev/null +++ b/framework/middleware/test.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "log" + + "git.vinchent.xyz/vinchent/go-web/framework" +) + +func Test1() framework.ControllerHandler { + return func(c *framework.Context) error { + log.Println("middleware test1 pre") + c.Next() + log.Println("middleware test1 post") + return nil + } +} + +func Test2() framework.ControllerHandler { + return func(c *framework.Context) error { + log.Println("middleware test2 pre") + c.Next() + log.Println("middleware test2 post") + return nil + } +} + +func Test3() framework.ControllerHandler { + return func(c *framework.Context) error { + log.Println("middleware test3 pre") + c.Next() + log.Println("middleware test3 post") + return nil + } +} diff --git a/framework/trie.go b/framework/trie.go index a118811..aad7924 100644 --- a/framework/trie.go +++ b/framework/trie.go @@ -15,6 +15,7 @@ func NewTrie() *Trie { } func (t *Trie) FindRoute(uri string) []ControllerHandler { + uri = strings.ToUpper(uri) uri = strings.TrimPrefix(uri, "/") if uri == "" { return t.root.handlers @@ -28,11 +29,12 @@ func (t *Trie) FindRoute(uri string) []ControllerHandler { return found.handlers } -func (t *Trie) AddRouter(uri string, handler ControllerHandler) error { +func (t *Trie) AddRouter(uri string, handlers []ControllerHandler) error { + uri = strings.ToUpper(uri) uri = strings.TrimPrefix(uri, "/") if uri == "" { t.root.isLast = true - t.root.handlers = append(t.root.handlers, handler) + t.root.handlers = append(t.root.handlers, handlers...) return nil } @@ -44,7 +46,7 @@ func (t *Trie) AddRouter(uri string, handler ControllerHandler) error { } // The route does not exist, add it to the tree - err := t.root.addRoute(upperUri, handler) + err := t.root.addRoute(upperUri, handlers) if err != nil { return err } @@ -109,7 +111,7 @@ func (n *node) findRoute(uri string) *node { return nil } -func (n *node) addRoute(uri string, handler ControllerHandler) error { +func (n *node) addRoute(uri string, handlers []ControllerHandler) error { splitted := strings.SplitN(uri, "/", 2) splittedLen := len(splitted) isLast := splittedLen == 1 @@ -125,12 +127,12 @@ func (n *node) addRoute(uri string, handler ControllerHandler) error { } else { // otherwise, set the child child.isLast = true - child.handlers = append(child.handlers, handler) + child.handlers = append(child.handlers, handlers...) return nil } } // More segments to check - return child.addRoute(splitted[1], handler) + return child.addRoute(splitted[1], handlers) } } @@ -138,14 +140,14 @@ func (n *node) addRoute(uri string, handler ControllerHandler) error { new := newNode(splitted[0]) if isLast { // this is the end - new.handlers = append(new.handlers, handler) + new.handlers = append(new.handlers, handlers...) new.isLast = true n.children = append(n.children, new) return nil } // continue new.isLast = false - err := new.addRoute(splitted[1], handler) + err := new.addRoute(splitted[1], handlers) if err != nil { return err } diff --git a/routes.go b/routes.go index a28e611..106fcc1 100644 --- a/routes.go +++ b/routes.go @@ -1,12 +1,17 @@ package main -import "git.vinchent.xyz/vinchent/go-web/framework" +import ( + "git.vinchent.xyz/vinchent/go-web/framework" + "git.vinchent.xyz/vinchent/go-web/framework/middleware" +) func registerRouter(core *framework.Core) { + core.Use(middleware.Test1(), middleware.Test2()) core.Get("/user/login", UserLoginController) subjectApi := core.Group("/subject") { + subjectApi.Use(middleware.Test3()) subjectApi.Delete("/:id", SubjectDelController) subjectApi.Put("/:id", SubjectUpdateController) subjectApi.Get("/:id", SubjectGetController)