Compare commits

..

No commits in common. "3b16d6b16a52b77edbfbe579c38eda5b1d0b9dda" and "1aa9b78bdceb8a7055b285b2ece2c39d3937a17b" have entirely different histories.

8 changed files with 45 additions and 202 deletions

View File

@ -8,8 +8,7 @@ import (
// Core is the core struct of the framework // Core is the core struct of the framework
type Core struct { type Core struct {
router map[string]*Trie router map[string]*Trie
middlewares []ControllerHandler
} }
// NewCore initialize the Core. // NewCore initialize the Core.
@ -29,44 +28,40 @@ func NewCore() *Core {
} }
// Get is a simple get router // Get is a simple get router
func (c *Core) Get(url string, handlers ...ControllerHandler) { func (c *Core) Get(url string, handler ControllerHandler) {
allHandlers := append(c.middlewares, handlers...) upperUrl := strings.ToUpper(url)
if err := c.router["GET"].AddRouter(url, allHandlers); err != nil { if err := c.router["GET"].AddRouter(upperUrl, handler); err != nil {
log.Println(err) log.Println(err)
} }
} }
// Post is a simple post router // Post is a simple post router
func (c *Core) Post(url string, handlers ...ControllerHandler) { func (c *Core) Post(url string, handler ControllerHandler) {
allHandlers := append(c.middlewares, handlers...) upperUrl := strings.ToUpper(url)
if err := c.router["POST"].AddRouter(url, allHandlers); err != nil { if err := c.router["POST"].AddRouter(upperUrl, handler); err != nil {
log.Println(err) log.Println(err)
} }
} }
// Put is a simple put router // Put is a simple put router
func (c *Core) Put(url string, handlers ...ControllerHandler) { func (c *Core) Put(url string, handler ControllerHandler) {
allHandlers := append(c.middlewares, handlers...) upperUrl := strings.ToUpper(url)
if err := c.router["PUT"].AddRouter(url, allHandlers); err != nil { if err := c.router["PUT"].AddRouter(upperUrl, handler); err != nil {
log.Println(err) log.Println(err)
} }
} }
// Delete is a simple delete router // Delete is a simple delete router
func (c *Core) Delete(url string, handlers ...ControllerHandler) { func (c *Core) Delete(url string, handler ControllerHandler) {
allHandlers := append(c.middlewares, handlers...) upperUrl := strings.ToUpper(url)
if err := c.router["DELETE"].AddRouter(url, allHandlers); err != nil { if err := c.router["DELETE"].AddRouter(upperUrl, handler); err != nil {
log.Println(err) log.Println(err)
} }
} }
// Use registers middlewares
func (c *Core) Use(middlewares ...ControllerHandler) {
c.middlewares = append(c.middlewares, middlewares...)
}
// FindRouteByRequest finds route using the request // FindRouteByRequest finds route using the request
func (c *Core) FindRouteByRequest(r *http.Request) []ControllerHandler { func (c *Core) FindRouteByRequest(r *http.Request) []ControllerHandler {
upperUri := strings.ToUpper(r.URL.Path)
upperMethod := strings.ToUpper(r.Method) upperMethod := strings.ToUpper(r.Method)
mapper, ok := c.router[upperMethod] mapper, ok := c.router[upperMethod]
@ -75,7 +70,7 @@ func (c *Core) FindRouteByRequest(r *http.Request) []ControllerHandler {
return nil return nil
} }
controllers := mapper.FindRoute(r.URL.Path) controllers := mapper.FindRoute(upperUri)
if controllers == nil { if controllers == nil {
log.Printf("URI %q is not recognized\n", r.URL.Path) log.Printf("URI %q is not recognized\n", r.URL.Path)
return nil return nil

View File

@ -2,18 +2,16 @@ package framework
// IGroup prefix routes // IGroup prefix routes
type IGroup interface { type IGroup interface {
Get(string, ...ControllerHandler) Get(string, ControllerHandler)
Post(string, ...ControllerHandler) Post(string, ControllerHandler)
Put(string, ...ControllerHandler) Put(string, ControllerHandler)
Delete(string, ...ControllerHandler) Delete(string, ControllerHandler)
Use(...ControllerHandler)
} }
// Group is the implementation of IGroup interface // Group is the implementation of IGroup interface
type Group struct { type Group struct {
core *Core core *Core
prefix string prefix string
middlewares []ControllerHandler
} }
// NewGroup create a new prefix group // NewGroup create a new prefix group
@ -25,30 +23,21 @@ func NewGroup(core *Core, prefix string) *Group {
} }
// Get is a simple get router of the group // Get is a simple get router of the group
func (g *Group) Get(url string, handlers ...ControllerHandler) { func (g *Group) Get(url string, handler ControllerHandler) {
allHandlers := append(g.middlewares, handlers...) g.core.Get(g.prefix+url, handler)
g.core.Get(g.prefix+url, allHandlers...)
} }
// Post is a simple post router of the group // Post is a simple post router of the group
func (g *Group) Post(url string, handlers ...ControllerHandler) { func (g *Group) Post(url string, handler ControllerHandler) {
allHandlers := append(g.middlewares, handlers...) g.core.Post(g.prefix+url, handler)
g.core.Post(g.prefix+url, allHandlers...)
} }
// Put is a simple put router of the group // Put is a simple put router of the group
func (g *Group) Put(url string, handlers ...ControllerHandler) { func (g *Group) Put(url string, handler ControllerHandler) {
allHandlers := append(g.middlewares, handlers...) g.core.Put(g.prefix+url, handler)
g.core.Put(g.prefix+url, allHandlers...)
} }
// Delete is a simple delete router of the group // Delete is a simple delete router of the group
func (g *Group) Delete(url string, handlers ...ControllerHandler) { func (g *Group) Delete(url string, handler ControllerHandler) {
allHandlers := append(g.middlewares, handlers...) g.core.Delete(g.prefix+url, handler)
g.core.Delete(g.prefix+url, allHandlers...)
}
// Use registers middlewares
func (g *Group) Use(middlewares ...ControllerHandler) {
g.middlewares = append(g.middlewares, middlewares...)
} }

View File

@ -1,21 +0,0 @@
package middleware
import (
"net/http"
"git.vinchent.xyz/vinchent/go-web/framework"
)
// Recovery is a middleware that recovers from the panic
func Recovery() framework.ControllerHandler {
return func(c *framework.Context) error {
defer func() {
if err := recover(); err != nil {
c.WriteJSON(http.StatusInternalServerError, err)
}
}()
c.Next()
return nil
}
}

View File

@ -1,34 +0,0 @@
package middleware
import (
"log"
"git.vinchent.xyz/vinchent/go-web/framework"
)
func Test1() framework.ControllerHandler {
return func(c *framework.Context) error {
log.Println("middleware test1 pre")
c.Next()
log.Println("middleware test1 post")
return nil
}
}
func Test2() framework.ControllerHandler {
return func(c *framework.Context) error {
log.Println("middleware test2 pre")
c.Next()
log.Println("middleware test2 post")
return nil
}
}
func Test3() framework.ControllerHandler {
return func(c *framework.Context) error {
log.Println("middleware test3 pre")
c.Next()
log.Println("middleware test3 post")
return nil
}
}

View File

@ -41,7 +41,6 @@ func Timeout(d time.Duration) framework.ControllerHandler {
log.Println("finish") log.Println("finish")
case <-durationCtx.Done(): case <-durationCtx.Done():
c.SetHasTimeout() c.SetHasTimeout()
c.GetResponseWriter().WriteHeader(http.StatusRequestTimeout)
c.GetResponseWriter().Write([]byte("time out")) c.GetResponseWriter().Write([]byte("time out"))
} }
return nil return nil

View File

@ -1,8 +1,6 @@
package middleware package middleware
import ( import (
"bytes"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -15,89 +13,13 @@ func TestTimeout(t *testing.T) {
t.Run("Test timeout handler", func(t *testing.T) { t.Run("Test timeout handler", func(t *testing.T) {
timeoutHandler := Timeout(1 * time.Millisecond) timeoutHandler := Timeout(1 * time.Millisecond)
longHandler := func(c *framework.Context) error { request := httptest.NewRequest(http.MethodGet, "/", nil)
time.Sleep(2 * time.Millisecond) response := httptest.NewRecorder()
return nil c := framework.NewContext(response, request)
err := timeoutHandler(c)
if err != nil {
t.Fatal(err)
} }
res := prepareMiddlewareTest(t, timeoutHandler, longHandler)
assertCode(t, res.StatusCode, http.StatusRequestTimeout)
assertBody(t, res.Body, "time out")
})
t.Run("Test no timeout", func(t *testing.T) {
timeoutHandler := Timeout(2 * time.Millisecond)
quickHandler := func(c *framework.Context) error {
// time.Sleep(1 * time.Millisecond)
c.WriteJSON(http.StatusOK, "ok")
return nil
}
res := prepareMiddlewareTest(t, timeoutHandler, quickHandler)
assertCode(t, res.StatusCode, http.StatusOK)
}) })
} }
func TestRecover(t *testing.T) {
t.Run("Test panic", func(t *testing.T) {
recoverer := Recovery()
panicHandler := func(c *framework.Context) error {
panic("panic")
}
res := prepareMiddlewareTest(t, recoverer, panicHandler)
assertCode(t, res.StatusCode, http.StatusInternalServerError)
})
t.Run("Test no panic", func(t *testing.T) {
recoverer := Recovery()
normalHandler := func(c *framework.Context) error {
c.WriteJSON(http.StatusOK, "ok")
return nil
}
res := prepareMiddlewareTest(t, recoverer, normalHandler)
assertCode(t, res.StatusCode, http.StatusOK)
})
}
func prepareMiddlewareTest(
t testing.TB,
mid framework.ControllerHandler,
in framework.ControllerHandler,
) *http.Response {
t.Helper()
request := httptest.NewRequest(http.MethodGet, "/", nil)
response := httptest.NewRecorder()
c := framework.NewContext(response, request)
c.SetHandlers([]framework.ControllerHandler{in})
err := mid(c)
if err != nil {
t.Fatal(err)
}
res := response.Result()
return res
}
func assertCode(t testing.TB, got int, want int) {
t.Helper()
if got != want {
t.Errorf("status code got %d, want %d", got, want)
}
}
func assertBody(t testing.TB, got io.Reader, want string) {
t.Helper()
buf, _ := io.ReadAll(got)
if cmp := bytes.Compare(buf, []byte(want)); cmp != 0 {
t.Errorf("got %q, want %q", string(buf), want)
}
}

View File

@ -15,7 +15,6 @@ func NewTrie() *Trie {
} }
func (t *Trie) FindRoute(uri string) []ControllerHandler { func (t *Trie) FindRoute(uri string) []ControllerHandler {
uri = strings.ToUpper(uri)
uri = strings.TrimPrefix(uri, "/") uri = strings.TrimPrefix(uri, "/")
if uri == "" { if uri == "" {
return t.root.handlers return t.root.handlers
@ -29,12 +28,11 @@ func (t *Trie) FindRoute(uri string) []ControllerHandler {
return found.handlers return found.handlers
} }
func (t *Trie) AddRouter(uri string, handlers []ControllerHandler) error { func (t *Trie) AddRouter(uri string, handler ControllerHandler) error {
uri = strings.ToUpper(uri)
uri = strings.TrimPrefix(uri, "/") uri = strings.TrimPrefix(uri, "/")
if uri == "" { if uri == "" {
t.root.isLast = true t.root.isLast = true
t.root.handlers = append(t.root.handlers, handlers...) t.root.handlers = append(t.root.handlers, handler)
return nil return nil
} }
@ -46,7 +44,7 @@ func (t *Trie) AddRouter(uri string, handlers []ControllerHandler) error {
} }
// The route does not exist, add it to the tree // The route does not exist, add it to the tree
err := t.root.addRoute(upperUri, handlers) err := t.root.addRoute(upperUri, handler)
if err != nil { if err != nil {
return err return err
} }
@ -111,7 +109,7 @@ func (n *node) findRoute(uri string) *node {
return nil return nil
} }
func (n *node) addRoute(uri string, handlers []ControllerHandler) error { func (n *node) addRoute(uri string, handler ControllerHandler) error {
splitted := strings.SplitN(uri, "/", 2) splitted := strings.SplitN(uri, "/", 2)
splittedLen := len(splitted) splittedLen := len(splitted)
isLast := splittedLen == 1 isLast := splittedLen == 1
@ -127,12 +125,12 @@ func (n *node) addRoute(uri string, handlers []ControllerHandler) error {
} else { } else {
// otherwise, set the child // otherwise, set the child
child.isLast = true child.isLast = true
child.handlers = append(child.handlers, handlers...) child.handlers = append(child.handlers, handler)
return nil return nil
} }
} }
// More segments to check // More segments to check
return child.addRoute(splitted[1], handlers) return child.addRoute(splitted[1], handler)
} }
} }
@ -140,14 +138,14 @@ func (n *node) addRoute(uri string, handlers []ControllerHandler) error {
new := newNode(splitted[0]) new := newNode(splitted[0])
if isLast { if isLast {
// this is the end // this is the end
new.handlers = append(new.handlers, handlers...) new.handlers = append(new.handlers, handler)
new.isLast = true new.isLast = true
n.children = append(n.children, new) n.children = append(n.children, new)
return nil return nil
} }
// continue // continue
new.isLast = false new.isLast = false
err := new.addRoute(splitted[1], handlers) err := new.addRoute(splitted[1], handler)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,17 +1,12 @@
package main package main
import ( import "git.vinchent.xyz/vinchent/go-web/framework"
"git.vinchent.xyz/vinchent/go-web/framework"
"git.vinchent.xyz/vinchent/go-web/framework/middleware"
)
func registerRouter(core *framework.Core) { func registerRouter(core *framework.Core) {
core.Use(middleware.Test1(), middleware.Test2())
core.Get("/user/login", UserLoginController) core.Get("/user/login", UserLoginController)
subjectApi := core.Group("/subject") subjectApi := core.Group("/subject")
{ {
subjectApi.Use(middleware.Test3())
subjectApi.Delete("/:id", SubjectDelController) subjectApi.Delete("/:id", SubjectDelController)
subjectApi.Put("/:id", SubjectUpdateController) subjectApi.Put("/:id", SubjectUpdateController)
subjectApi.Get("/:id", SubjectGetController) subjectApi.Get("/:id", SubjectGetController)