Add convenience method to check if websockets required (#779)
* Add convenience method to check if websockets required * Add tests * Fix up tests for develop branch
This commit is contained in:
		
							
								
								
									
										10
									
								
								context.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								context.go
									
									
									
									
									
								
							@ -383,6 +383,16 @@ func (c *Context) ContentType() string {
 | 
				
			|||||||
	return filterFlags(c.requestHeader("Content-Type"))
 | 
						return filterFlags(c.requestHeader("Content-Type"))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// IsWebsocket returns true if the request headers indicate that a websocket
 | 
				
			||||||
 | 
					// handshake is being initiated by the client.
 | 
				
			||||||
 | 
					func (c *Context) IsWebsocket() bool {
 | 
				
			||||||
 | 
						if strings.Contains(strings.ToLower(c.requestHeader("Connection")), "upgrade") &&
 | 
				
			||||||
 | 
							strings.ToLower(c.requestHeader("Upgrade")) == "websocket" {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *Context) requestHeader(key string) string {
 | 
					func (c *Context) requestHeader(key string) string {
 | 
				
			||||||
	if values, _ := c.Request.Header[key]; len(values) > 0 {
 | 
						if values, _ := c.Request.Header[key]; len(values) > 0 {
 | 
				
			||||||
		return values[0]
 | 
							return values[0]
 | 
				
			||||||
 | 
				
			|||||||
@ -814,3 +814,25 @@ func TestContextGolangContext(t *testing.T) {
 | 
				
			|||||||
	assert.Equal(t, c.Value("foo"), "bar")
 | 
						assert.Equal(t, c.Value("foo"), "bar")
 | 
				
			||||||
	assert.Nil(t, c.Value(1))
 | 
						assert.Nil(t, c.Value(1))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestWebsocketsRequired(t *testing.T) {
 | 
				
			||||||
 | 
						// Example request from spec: https://tools.ietf.org/html/rfc6455#section-1.2
 | 
				
			||||||
 | 
						c, _ := CreateTestContext(httptest.NewRecorder())
 | 
				
			||||||
 | 
						c.Request, _ = http.NewRequest("GET", "/chat", nil)
 | 
				
			||||||
 | 
						c.Request.Header.Set("Host", "server.example.com")
 | 
				
			||||||
 | 
						c.Request.Header.Set("Upgrade", "websocket")
 | 
				
			||||||
 | 
						c.Request.Header.Set("Connection", "Upgrade")
 | 
				
			||||||
 | 
						c.Request.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
 | 
				
			||||||
 | 
						c.Request.Header.Set("Origin", "http://example.com")
 | 
				
			||||||
 | 
						c.Request.Header.Set("Sec-WebSocket-Protocol", "chat, superchat")
 | 
				
			||||||
 | 
						c.Request.Header.Set("Sec-WebSocket-Version", "13")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.True(t, c.IsWebsocket())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Normal request, no websocket required.
 | 
				
			||||||
 | 
						c, _ = CreateTestContext(httptest.NewRecorder())
 | 
				
			||||||
 | 
						c.Request, _ = http.NewRequest("GET", "/chat", nil)
 | 
				
			||||||
 | 
						c.Request.Header.Set("Host", "server.example.com")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.False(t, c.IsWebsocket())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user