diff --git a/capture.go b/capture.go index 354469a..e191840 100644 --- a/capture.go +++ b/capture.go @@ -1,5 +1,26 @@ package main +import ( + "strconv" + "sync" +) + +var captureID int +var captures CaptureList + +type CaptureRepository interface { + Insert(capture Capture) + RemoveAll() + Find(captureID string) *Capture + FindAll() []Capture +} + +type CaptureList struct { + items []Capture + mux sync.Mutex + maxItems int +} + type Capture struct { ID int `json:"id"` Path string `json:"path"` @@ -16,27 +37,54 @@ type CaptureMetadata struct { Status int `json:"status"` } -type Captures []Capture - -func (items *Captures) Add(capture Capture) { - *items = append(*items, capture) -} - -func (items *Captures) RemoveLastAfterReaching(maxItems int) { - if len(*items) > maxItems { - *items = (*items)[1:] +func (c *Capture) Metadata() CaptureMetadata { + return CaptureMetadata{ + ID: c.ID, + Path: c.Path, + Method: c.Method, + Status: c.Status, } } -func (items *Captures) MetadataOnly() []CaptureMetadata { - refs := make([]CaptureMetadata, len(*items)) - for i, item := range *items { - refs[i] = CaptureMetadata{ - ID: item.ID, - Path: item.Path, - Method: item.Method, - Status: item.Status, +func NewCapturesRepository(maxItems int) CaptureRepository { + return &CaptureList{ + maxItems: maxItems, + } +} + +func (c *CaptureList) Insert(capture Capture) { + c.mux.Lock() + defer c.mux.Unlock() + capture.ID = newID() + c.items = append(c.items, capture) + if len(c.items) > c.maxItems { + c.items = c.items[1:] + } +} + +func (c *CaptureList) Find(captureID string) *Capture { + c.mux.Lock() + defer c.mux.Unlock() + idInt, _ := strconv.Atoi(captureID) + for _, c := range c.items { + if c.ID == idInt { + return &c } } - return refs + return nil +} + +func (c *CaptureList) RemoveAll() { + c.mux.Lock() + defer c.mux.Unlock() + c.items = nil +} + +func (c *CaptureList) FindAll() []Capture { + return c.items +} + +func newID() int { + captureID++ + return captureID } diff --git a/main.go b/main.go index bbefd63..2bd6dce 100644 --- a/main.go +++ b/main.go @@ -8,17 +8,14 @@ import ( "io" "io/ioutil" "net/http" + "net/http/httptest" "net/http/httputil" "net/url" - "strconv" "strings" - "sync" "github.com/googollee/go-socket.io" ) -var captures Captures - var dashboardSocket socketio.Socket func main() { @@ -27,11 +24,14 @@ func main() { } func startCapture(config Config) { - http.Handle("/", proxyHandler(config)) - http.Handle("/socket.io/", dashboardSocketHandler(config)) + + repo := NewCapturesRepository(config.MaxCaptures) + + http.Handle("/", NewRecorder(repo, proxyHandler(config.TargetURL))) + http.Handle("/socket.io/", dashboardSocketHandler(repo, config)) http.Handle(config.DashboardPath, dashboardHandler()) - http.Handle(config.DashboardClearPath, dashboardClearHandler()) - http.Handle(config.DashboardItemInfoPath, dashboardItemInfoHandler()) + http.Handle(config.DashboardClearPath, dashboardClearHandler(repo)) + http.Handle(config.DashboardItemInfoPath, dashboardItemInfoHandler(repo)) captureHost := fmt.Sprintf("http://localhost:%s", config.ProxyPort) @@ -41,7 +41,7 @@ func startCapture(config Config) { fmt.Println(http.ListenAndServe(":"+config.ProxyPort, nil)) } -func dashboardSocketHandler(config Config) http.Handler { +func dashboardSocketHandler(repo CaptureRepository, config Config) http.Handler { server, err := socketio.NewServer(nil) if err != nil { fmt.Printf("socket server error: %v\n", err) @@ -49,7 +49,7 @@ func dashboardSocketHandler(config Config) http.Handler { server.On("connection", func(so socketio.Socket) { dashboardSocket = so dashboardSocket.Emit("config", config) - emitToDashboard(captures) + emitToDashboard(repo.FindAll()) }) server.On("error", func(so socketio.Socket, err error) { fmt.Printf("socket error: %v\n", err) @@ -57,10 +57,10 @@ func dashboardSocketHandler(config Config) http.Handler { return server } -func dashboardClearHandler() http.Handler { +func dashboardClearHandler(repo CaptureRepository) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - captures = nil - emitToDashboard(captures) + repo.RemoveAll() + emitToDashboard(nil) rw.WriteHeader(http.StatusOK) }) } @@ -72,71 +72,75 @@ func dashboardHandler() http.Handler { }) } -func dashboardItemInfoHandler() http.Handler { +func dashboardItemInfoHandler(repo CaptureRepository) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - idStr := req.URL.Path[strings.LastIndex(req.URL.Path, "/")+1:] - idInt, _ := strconv.Atoi(idStr) - for _, c := range captures { - if c.ID == idInt { - rw.Header().Add("Content-Type", "application/json") - json.NewEncoder(rw).Encode(c) - break - } + id := req.URL.Path[strings.LastIndex(req.URL.Path, "/")+1:] + capture := repo.Find(id) + if capture == nil { + http.Error(rw, "Item Not Found", http.StatusNotFound) + return } + rw.Header().Add("Content-Type", "application/json") + json.NewEncoder(rw).Encode(capture) }) } -func proxyHandler(config Config) http.Handler { - url, _ := url.Parse(config.TargetURL) - captureID := 0 - mux := sync.Mutex{} +func NewRecorder(repo CaptureRepository, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - req.Host = url.Host - req.URL.Host = url.Host - req.URL.Scheme = url.Scheme - reqDump, err := dumpRequest(req) if err != nil { fmt.Printf("could not dump request: %v\n", err) } - proxy := httputil.NewSingleHostReverseProxy(url) - proxy.ModifyResponse = func(res *http.Response) error { - resDump, err := dumpResponse(res) - if err != nil { - return fmt.Errorf("could not dump response: %v", err) - } - mux.Lock() - captureID++ - capture := Capture{ - ID: captureID, - Path: req.URL.Path, - Method: req.Method, - Status: res.StatusCode, - Request: string(reqDump), - Response: string(resDump), - } - captures.Add(capture) - captures.RemoveLastAfterReaching(config.MaxCaptures) - mux.Unlock() - emitToDashboard(captures) - return nil + rec := httptest.NewRecorder() + + next.ServeHTTP(rec, req) + + for k, v := range rec.HeaderMap { + rw.Header()[k] = v } - proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { - fmt.Printf("uh oh | %v | %s\n", err, req.URL) + rw.WriteHeader(rec.Code) + rw.Write(rec.Body.Bytes()) + + res := rec.Result() + resDump, err := dumpResponse(res) + if err != nil { + fmt.Printf("could not dump response: %v\n", err) } + capture := Capture{ + Path: req.URL.Path, + Method: req.Method, + Status: res.StatusCode, + Request: string(reqDump), + Response: string(resDump), + } + repo.Insert(capture) + emitToDashboard(repo.FindAll()) + }) +} + +func proxyHandler(URL string) http.Handler { + url, _ := url.Parse(URL) + proxy := httputil.NewSingleHostReverseProxy(url) + proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { + fmt.Printf("uh oh | %v | %s %s\n", err, req.Method, req.URL) + } + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + req.Host = url.Host + req.URL.Host = url.Host + req.URL.Scheme = url.Scheme proxy.ServeHTTP(rw, req) }) } func dumpRequest(req *http.Request) ([]byte, error) { if req.Header.Get("Content-Encoding") == "gzip" { - var originalBody bytes.Buffer - tee := io.TeeReader(req.Body, &originalBody) - reader, _ := gzip.NewReader(tee) + var reqBody []byte + req.Body, reqBody = drain(req.Body) + reader, _ := gzip.NewReader(bytes.NewReader(reqBody)) req.Body = ioutil.NopCloser(reader) reqDump, err := httputil.DumpRequest(req, true) - req.Body = ioutil.NopCloser(&originalBody) + req.Body = ioutil.NopCloser(bytes.NewReader(reqBody)) return reqDump, err } return httputil.DumpRequest(req, true) @@ -144,19 +148,30 @@ func dumpRequest(req *http.Request) ([]byte, error) { func dumpResponse(res *http.Response) ([]byte, error) { if res.Header.Get("Content-Encoding") == "gzip" { - var originalBody bytes.Buffer - tee := io.TeeReader(res.Body, &originalBody) - reader, _ := gzip.NewReader(tee) + var resBody []byte + res.Body, resBody = drain(res.Body) + reader, _ := gzip.NewReader(bytes.NewReader(resBody)) res.Body = ioutil.NopCloser(reader) resDump, err := httputil.DumpResponse(res, true) - res.Body = ioutil.NopCloser(&originalBody) + res.Body = ioutil.NopCloser(bytes.NewReader(resBody)) return resDump, err } return httputil.DumpResponse(res, true) } -func emitToDashboard(captures Captures) { - if dashboardSocket != nil { - dashboardSocket.Emit("captures", captures.MetadataOnly()) - } +func drain(b io.ReadCloser) (io.ReadCloser, []byte) { + all, _ := ioutil.ReadAll(b) + b.Close() + return ioutil.NopCloser(bytes.NewReader(all)), all +} + +func emitToDashboard(captures []Capture) { + if dashboardSocket == nil { + return + } + metadatas := make([]CaptureMetadata, len(captures)) + for i, capture := range captures { + metadatas[i] = capture.Metadata() + } + dashboardSocket.Emit("captures", metadatas) } diff --git a/main_test.go b/main_test.go index 25425e5..43b0403 100644 --- a/main_test.go +++ b/main_test.go @@ -18,6 +18,7 @@ import ( // Test the reverse proxy handler func TestProxyHandler(t *testing.T) { + // given tt := []TestCase{ GetRequest(), PostRequest(), @@ -25,10 +26,12 @@ func TestProxyHandler(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { service := httptest.NewServer(http.HandlerFunc(tc.service)) - capture := httptest.NewServer(proxyHandler(Config{TargetURL: service.URL})) + capture := httptest.NewServer(proxyHandler(service.URL)) + // when resp := tc.request(capture.URL) + // then tc.test(t, resp) resp.Body.Close() @@ -88,13 +91,16 @@ func PostRequest() TestCase { func TestDumpRequest(t *testing.T) { msg := "hello" + // given req, err := http.NewRequest(http.MethodPost, "http://localhost:9000/", strings.NewReader(msg)) if err != nil { t.Errorf("Could not create request: %v", err) } + // when body, err := dumpRequest(req) + // then if err != nil { t.Errorf("Dump Request error: %v", err) } @@ -106,14 +112,17 @@ func TestDumpRequest(t *testing.T) { func TestDumpRequestGzip(t *testing.T) { msg := "hello" + // given req, err := http.NewRequest(http.MethodPost, "http://localhost:9000/", strings.NewReader(gzipStr(msg))) req.Header.Set("Content-Encoding", "gzip") if err != nil { t.Errorf("Could not create request: %v", err) } + // when body, err := dumpRequest(req) + // then if err != nil { t.Errorf("Dump Request Gzip error: %v", err) } @@ -125,10 +134,13 @@ func TestDumpRequestGzip(t *testing.T) { func TestDumpResponse(t *testing.T) { msg := "hello" + // given res := &http.Response{Body: ioutil.NopCloser(strings.NewReader(msg))} + // when body, err := dumpResponse(res) + // then if err != nil { t.Errorf("Dump Response Error: %v", err) } @@ -140,14 +152,15 @@ func TestDumpResponse(t *testing.T) { func TestDumpResponseGzip(t *testing.T) { msg := "hello" - // make a response + // given h := make(http.Header) h.Set("Content-Encoding", "gzip") res := &http.Response{Header: h, Body: ioutil.NopCloser(strings.NewReader(gzipStr(msg)))} - // dump it + // when body, err := dumpResponse(res) + // then if err != nil { t.Errorf("Dump Response error: %v", err) } @@ -160,17 +173,21 @@ func TestCaptureIDConcurrence(t *testing.T) { // This test bothers me + // given + interactions := 1000 - // Startup servers service := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) rw.WriteHeader(http.StatusOK) })) - capture := httptest.NewServer(proxyHandler(Config{TargetURL: service.URL, MaxCaptures: interactions})) + repo := NewCapturesRepository(interactions) + capture := httptest.NewServer(NewRecorder(repo, proxyHandler(service.URL))) defer service.Close() defer capture.Close() + // when + // Starts go routines so that captureID is incremented concurrently within proxyHandler() wg := &sync.WaitGroup{} wg.Add(interactions) @@ -178,14 +195,20 @@ func TestCaptureIDConcurrence(t *testing.T) { go func() { _, err := http.Get(capture.URL) if err != nil { - t.Errorf("Request Failed: %v", err) + t.Fatalf("Request Failed: %v", err) } wg.Done() }() } wg.Wait() + // then + // Tests if captures IDs are sequential + captures := repo.FindAll() + if len(captures) == 0 { + t.Fatalf("No captures found") + } ids := make([]int, len(captures)) for i := 0; i < len(captures); i++ { ids[i] = captures[i].ID