From 0d73abcefdc4090aa27b87a799c504551af841c8 Mon Sep 17 00:00:00 2001 From: Fabricio Date: Sun, 11 Nov 2018 14:54:35 -0200 Subject: [PATCH] change reverse proxy core --- main.go | 138 +++++++++++++++++++++---------------------- main_test.go | 161 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+), 72 deletions(-) create mode 100644 main_test.go diff --git a/main.go b/main.go index f118872..c5a5fa3 100644 --- a/main.go +++ b/main.go @@ -16,30 +16,21 @@ import ( "github.com/googollee/go-socket.io" ) -type transport struct { - http.RoundTripper - maxItems int - currItemID int -} - var captures Captures var dashboardSocket socketio.Socket func main() { config := ReadConfig() + startCapture(config) +} - transp := &transport{ - RoundTripper: http.DefaultTransport, - maxItems: config.MaxCaptures, - currItemID: 0, - } - - http.Handle("/", getProxyHandler(config.TargetURL, transp)) - http.Handle("/socket.io/", getDashboardSocketHandler(config)) - http.Handle(config.DashboardPath, getDashboardHandler()) - http.Handle(config.DashboardClearPath, getDashboardClearHandler()) - http.Handle(config.DashboardItemInfoPath, getDashboardItemInfoHandler()) +func startCapture(config Config) { + http.Handle("/", proxyHandler(config)) + http.Handle("/socket.io/", dashboardSocketHandler(config)) + http.Handle(config.DashboardPath, dashboardHandler()) + http.Handle(config.DashboardClearPath, dashboardClearHandler()) + http.Handle(config.DashboardItemInfoPath, dashboardItemInfoHandler()) proxyHost := fmt.Sprintf("http://localhost:%s", config.ProxyPort) @@ -49,7 +40,7 @@ func main() { fmt.Println(http.ListenAndServe(":"+config.ProxyPort, nil)) } -func getDashboardSocketHandler(config Config) http.Handler { +func dashboardSocketHandler(config Config) http.Handler { server, err := socketio.NewServer(nil) if err != nil { fmt.Println("socket server error", err) @@ -65,7 +56,7 @@ func getDashboardSocketHandler(config Config) http.Handler { return server } -func getDashboardClearHandler() http.Handler { +func dashboardClearHandler() http.Handler { return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { captures = nil emitToDashboard(captures) @@ -73,14 +64,14 @@ func getDashboardClearHandler() http.Handler { }) } -func getDashboardHandler() http.Handler { +func dashboardHandler() http.Handler { return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { res.Header().Add("Content-Type", "text/html") res.Write([]byte(dashboardHTML)) }) } -func getDashboardItemInfoHandler() http.Handler { +func dashboardItemInfoHandler() http.Handler { return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { idStr := req.URL.Path[strings.LastIndex(req.URL.Path, "/")+1:] idInt, _ := strconv.Atoi(idStr) @@ -94,67 +85,70 @@ func getDashboardItemInfoHandler() http.Handler { }) } -func getProxyHandler(targetURL string, transp *transport) http.Handler { - url, _ := url.Parse(targetURL) - proxy := httputil.NewSingleHostReverseProxy(url) - proxy.Transport = transp - return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - request.Host = request.URL.Host - proxy.ServeHTTP(response, request) +func proxyHandler(config Config) http.Handler { + url, _ := url.Parse(config.TargetURL) + captureID := 0 + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + req.Host = url.Host + req.URL.Host = url.Host + req.URL.Scheme = url.Scheme + + reqDump, err := dumpRequest(req) + if err != nil { + fmt.Printf("Could not dump request: %v\n", err) + } + + proxy := httputil.NewSingleHostReverseProxy(url) + proxy.ModifyResponse = func(res *http.Response) error { + resDump, err := dumpResponse(res) + if err != nil { + return fmt.Errorf("Could not dump response: %v", err) + } + captureID++ + capture := Capture{ + ID: captureID, + Path: req.URL.Path, + Method: req.Method, + Status: res.StatusCode, + Request: string(reqDump), + Response: string(resDump), + } + captures.Add(capture) + captures.RemoveLastAfterReaching(config.MaxCaptures) + emitToDashboard(captures) + return nil + } + proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { + fmt.Printf("uh oh | %v | %s\n", err, req.URL) + } + proxy.ServeHTTP(rw, req) }) } -func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { - - reqDump, err := dumpRequest(req) - if err != nil { - return nil, err - } - - res, err := t.RoundTripper.RoundTrip(req) - if err != nil { - return nil, fmt.Errorf("uh oh | %v | %s", err, req.URL) - } - - resDump, err := dumpResponse(res) - if err != nil { - return nil, err - } - - capture := Capture{ - ID: t.NewItemID(), - Path: req.URL.Path, - Method: req.Method, - Status: res.StatusCode, - Request: string(reqDump), - Response: string(resDump), - } - - captures.Add(capture) - captures.RemoveLastAfterReaching(t.maxItems) - emitToDashboard(captures) - return res, nil -} - -func (t *transport) NewItemID() int { - t.currItemID++ - return t.currItemID -} - func dumpRequest(req *http.Request) ([]byte, error) { + if req.Header.Get("Content-Encoding") == "gzip" { + var originalBody bytes.Buffer + tee := io.TeeReader(req.Body, &originalBody) + reader, _ := gzip.NewReader(tee) + req.Body = ioutil.NopCloser(reader) + reqDump, err := httputil.DumpRequest(req, true) + req.Body = ioutil.NopCloser(&originalBody) + return reqDump, err + } return httputil.DumpRequest(req, true) } func dumpResponse(res *http.Response) ([]byte, error) { - var originalBody bytes.Buffer - reader := io.TeeReader(res.Body, &originalBody) if res.Header.Get("Content-Encoding") == "gzip" { - reader, _ = gzip.NewReader(reader) + var originalBody bytes.Buffer + tee := io.TeeReader(res.Body, &originalBody) + reader, _ := gzip.NewReader(tee) + res.Body = ioutil.NopCloser(reader) + resDump, err := httputil.DumpResponse(res, true) + res.Body = ioutil.NopCloser(&originalBody) + return resDump, err } - res.Body = ioutil.NopCloser(reader) - resDump, err := httputil.DumpResponse(res, true) - res.Body = ioutil.NopCloser(&originalBody) - return resDump, err + return httputil.DumpResponse(res, true) } func emitToDashboard(captures Captures) { diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..58095fe --- /dev/null +++ b/main_test.go @@ -0,0 +1,161 @@ +package main + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// Test the reverse proxy handler +func TestProxyHandler(t *testing.T) { + tt := []TestCase{ + GetRequest(), + PostRequest(), + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + service := httptest.NewServer(http.HandlerFunc(tc.service)) + capture := httptest.NewServer(proxyHandler(Config{TargetURL: service.URL})) + + resp := tc.request(capture.URL) + + tc.test(t, resp) + + resp.Body.Close() + capture.Close() + service.Close() + }) + } +} + +type TestCase struct { + name string + request func(string) *http.Response + service func(http.ResponseWriter, *http.Request) + test func(*testing.T, *http.Response) +} + +func GetRequest() TestCase { + msg := "hello" + return TestCase{ + name: "GetRequest", + request: func(url string) *http.Response { + res, _ := http.Get(url) + return res + }, + service: func(rw http.ResponseWriter, req *http.Request) { + fmt.Fprint(rw, string(msg)) + }, + test: func(t *testing.T, res *http.Response) { + body, _ := ioutil.ReadAll(res.Body) + if string(body) != msg { + t.Error("Wrong Body Response") + } + }, + } +} + +func PostRequest() TestCase { + msg := "hello" + return TestCase{ + name: "PostRequest", + request: func(url string) *http.Response { + res, _ := http.Post(url, "text/plain", strings.NewReader(msg)) + return res + }, + service: func(rw http.ResponseWriter, req *http.Request) { + io.Copy(rw, req.Body) + }, + test: func(t *testing.T, res *http.Response) { + body, _ := ioutil.ReadAll(res.Body) + if string(body) != msg { + t.Error("Wrong Body Response") + } + }, + } +} + +func TestDumpRequest(t *testing.T) { + msg := "hello" + + req, err := http.NewRequest(http.MethodPost, "http://localhost:9000/", strings.NewReader(msg)) + if err != nil { + t.Errorf("Could not create request: %v", err) + } + + body, err := dumpRequest(req) + + if err != nil { + t.Errorf("Dump Request error: %v", err) + } + if !strings.Contains(string(body), msg) { + t.Errorf("Dump Request is not '%s'", msg) + } +} + +func TestDumpRequestGzip(t *testing.T) { + msg := "hello" + + req, err := http.NewRequest(http.MethodPost, "http://localhost:9000/", strings.NewReader(gzipStr(msg))) + req.Header.Set("Content-Encoding", "gzip") + if err != nil { + t.Errorf("Could not create request: %v", err) + } + + body, err := dumpRequest(req) + + if err != nil { + t.Errorf("Dump Request Gzip error: %v", err) + } + if !strings.Contains(string(body), msg) { + t.Errorf("Dump Request Gzip is not '%s'", msg) + } +} + +func TestDumpResponse(t *testing.T) { + msg := "hello" + + res := &http.Response{Body: ioutil.NopCloser(strings.NewReader(msg))} + + body, err := dumpResponse(res) + + if err != nil { + t.Errorf("Dump Response Error: %v", err) + } + if !strings.Contains(string(body), msg) { + t.Errorf("Dump Response is not '%s'", msg) + } +} + +func TestDumpResponseGzip(t *testing.T) { + msg := "hello" + + // make a response + h := make(http.Header) + h.Set("Content-Encoding", "gzip") + res := &http.Response{Header: h, Body: ioutil.NopCloser(strings.NewReader(gzipStr(msg)))} + + // dump it + body, err := dumpResponse(res) + + if err != nil { + t.Errorf("Dump Response error: %v", err) + } + if !strings.Contains(string(body), msg) { + t.Error("Not hello") + } +} + +func gzipStr(str string) string { + var buff bytes.Buffer + g := gzip.NewWriter(&buff) + io.WriteString(g, str) + g.Close() + return buff.String() +}