From e22d4ea6227594df71edd1dc26cb49727bb37f2e Mon Sep 17 00:00:00 2001 From: pei0804 Date: Sat, 23 Mar 2019 21:39:50 +0900 Subject: [PATCH] fix --- swagger.go | 17 ++++++++++++++++- swagger_test.go | 23 +++++++++++++++++++++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/swagger.go b/swagger.go index 60d8ec3..d1dfeea 100644 --- a/swagger.go +++ b/swagger.go @@ -70,7 +70,22 @@ func CustomWrapHandler(config *Config, h *webdav.Handler) gin.HandlerFunc { // DisablingWrapHandler turn handler off // if specified environment variable passed -func DisablingWrapHandler(config *Config, h *webdav.Handler, envName string) gin.HandlerFunc { +func DisablingWrapHandler(h *webdav.Handler, envName string) gin.HandlerFunc { + eFlag := os.Getenv(envName) + if eFlag != "" { + return func(c *gin.Context) { + // Simulate behavior when route unspecified and + // return 404 HTTP code + c.String(404, "") + } + } + + return WrapHandler(h) +} + +// DisablingCustomWrapHandler turn handler off +// if specified environment variable passed +func DisablingCustomWrapHandler(config *Config, h *webdav.Handler, envName string) gin.HandlerFunc { eFlag := os.Getenv(envName) if eFlag != "" { return func(c *gin.Context) { diff --git a/swagger_test.go b/swagger_test.go index ae31d8d..0a9ff18 100644 --- a/swagger_test.go +++ b/swagger_test.go @@ -47,7 +47,7 @@ func TestDisablingWrapHandler(t *testing.T) { router := gin.New() disablingKey := "SWAGGER_DISABLE" - router.GET("/simple/*any", DisablingWrapHandler(&Config{}, swaggerFiles.Handler, disablingKey)) + router.GET("/simple/*any", DisablingWrapHandler(swaggerFiles.Handler, disablingKey)) w1 := performRequest("GET", "/simple/index.html", router) assert.Equal(t, 200, w1.Code) @@ -63,7 +63,7 @@ func TestDisablingWrapHandler(t *testing.T) { os.Setenv(disablingKey, "true") - router.GET("/disabling/*any", DisablingWrapHandler(&Config{}, swaggerFiles.Handler, disablingKey)) + router.GET("/disabling/*any", DisablingWrapHandler(swaggerFiles.Handler, disablingKey)) w11 := performRequest("GET", "/disabling/index.html", router) assert.Equal(t, 404, w11.Code) @@ -78,6 +78,25 @@ func TestDisablingWrapHandler(t *testing.T) { assert.Equal(t, 404, w44.Code) } +func TestDisablingCustomWrapHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + disablingKey := "SWAGGER_DISABLE2" + + router.GET("/simple/*any", DisablingCustomWrapHandler(&Config{}, swaggerFiles.Handler, disablingKey)) + + w1 := performRequest("GET", "/simple/index.html", router) + assert.Equal(t, 200, w1.Code) + + os.Setenv(disablingKey, "true") + + router.GET("/disabling/*any", DisablingCustomWrapHandler(&Config{}, swaggerFiles.Handler, disablingKey)) + + w11 := performRequest("GET", "/disabling/index.html", router) + assert.Equal(t, 404, w11.Code) +} + func performRequest(method, target string, router *gin.Engine) *httptest.ResponseRecorder { r := httptest.NewRequest(method, target, nil) w := httptest.NewRecorder()