better encapsulation

This commit is contained in:
Fabricio 2018-11-22 19:45:20 -02:00
parent ff6587ce2b
commit ecbfb4f67a
3 changed files with 175 additions and 89 deletions

View file

@ -1,5 +1,26 @@
package main
import (
"strconv"
"sync"
)
var captureID int
var captures CaptureList
type CaptureRepository interface {
Insert(capture Capture)
RemoveAll()
Find(captureID string) *Capture
FindAll() []Capture
}
type CaptureList struct {
items []Capture
mux sync.Mutex
maxItems int
}
type Capture struct {
ID int `json:"id"`
Path string `json:"path"`
@ -16,27 +37,54 @@ type CaptureMetadata struct {
Status int `json:"status"`
}
type Captures []Capture
func (items *Captures) Add(capture Capture) {
*items = append(*items, capture)
}
func (items *Captures) RemoveLastAfterReaching(maxItems int) {
if len(*items) > maxItems {
*items = (*items)[1:]
func (c *Capture) Metadata() CaptureMetadata {
return CaptureMetadata{
ID: c.ID,
Path: c.Path,
Method: c.Method,
Status: c.Status,
}
}
func (items *Captures) MetadataOnly() []CaptureMetadata {
refs := make([]CaptureMetadata, len(*items))
for i, item := range *items {
refs[i] = CaptureMetadata{
ID: item.ID,
Path: item.Path,
Method: item.Method,
Status: item.Status,
func NewCapturesRepository(maxItems int) CaptureRepository {
return &CaptureList{
maxItems: maxItems,
}
}
func (c *CaptureList) Insert(capture Capture) {
c.mux.Lock()
defer c.mux.Unlock()
capture.ID = newID()
c.items = append(c.items, capture)
if len(c.items) > c.maxItems {
c.items = c.items[1:]
}
}
func (c *CaptureList) Find(captureID string) *Capture {
c.mux.Lock()
defer c.mux.Unlock()
idInt, _ := strconv.Atoi(captureID)
for _, c := range c.items {
if c.ID == idInt {
return &c
}
}
return refs
return nil
}
func (c *CaptureList) RemoveAll() {
c.mux.Lock()
defer c.mux.Unlock()
c.items = nil
}
func (c *CaptureList) FindAll() []Capture {
return c.items
}
func newID() int {
captureID++
return captureID
}

145
main.go
View file

@ -8,17 +8,14 @@ import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"sync"
"github.com/googollee/go-socket.io"
)
var captures Captures
var dashboardSocket socketio.Socket
func main() {
@ -27,11 +24,14 @@ func main() {
}
func startCapture(config Config) {
http.Handle("/", proxyHandler(config))
http.Handle("/socket.io/", dashboardSocketHandler(config))
repo := NewCapturesRepository(config.MaxCaptures)
http.Handle("/", NewRecorder(repo, proxyHandler(config.TargetURL)))
http.Handle("/socket.io/", dashboardSocketHandler(repo, config))
http.Handle(config.DashboardPath, dashboardHandler())
http.Handle(config.DashboardClearPath, dashboardClearHandler())
http.Handle(config.DashboardItemInfoPath, dashboardItemInfoHandler())
http.Handle(config.DashboardClearPath, dashboardClearHandler(repo))
http.Handle(config.DashboardItemInfoPath, dashboardItemInfoHandler(repo))
captureHost := fmt.Sprintf("http://localhost:%s", config.ProxyPort)
@ -41,7 +41,7 @@ func startCapture(config Config) {
fmt.Println(http.ListenAndServe(":"+config.ProxyPort, nil))
}
func dashboardSocketHandler(config Config) http.Handler {
func dashboardSocketHandler(repo CaptureRepository, config Config) http.Handler {
server, err := socketio.NewServer(nil)
if err != nil {
fmt.Printf("socket server error: %v\n", err)
@ -49,7 +49,7 @@ func dashboardSocketHandler(config Config) http.Handler {
server.On("connection", func(so socketio.Socket) {
dashboardSocket = so
dashboardSocket.Emit("config", config)
emitToDashboard(captures)
emitToDashboard(repo.FindAll())
})
server.On("error", func(so socketio.Socket, err error) {
fmt.Printf("socket error: %v\n", err)
@ -57,10 +57,10 @@ func dashboardSocketHandler(config Config) http.Handler {
return server
}
func dashboardClearHandler() http.Handler {
func dashboardClearHandler(repo CaptureRepository) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
captures = nil
emitToDashboard(captures)
repo.RemoveAll()
emitToDashboard(nil)
rw.WriteHeader(http.StatusOK)
})
}
@ -72,71 +72,75 @@ func dashboardHandler() http.Handler {
})
}
func dashboardItemInfoHandler() http.Handler {
func dashboardItemInfoHandler(repo CaptureRepository) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
idStr := req.URL.Path[strings.LastIndex(req.URL.Path, "/")+1:]
idInt, _ := strconv.Atoi(idStr)
for _, c := range captures {
if c.ID == idInt {
rw.Header().Add("Content-Type", "application/json")
json.NewEncoder(rw).Encode(c)
break
}
id := req.URL.Path[strings.LastIndex(req.URL.Path, "/")+1:]
capture := repo.Find(id)
if capture == nil {
http.Error(rw, "Item Not Found", http.StatusNotFound)
return
}
rw.Header().Add("Content-Type", "application/json")
json.NewEncoder(rw).Encode(capture)
})
}
func proxyHandler(config Config) http.Handler {
url, _ := url.Parse(config.TargetURL)
captureID := 0
mux := sync.Mutex{}
func NewRecorder(repo CaptureRepository, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
req.Host = url.Host
req.URL.Host = url.Host
req.URL.Scheme = url.Scheme
reqDump, err := dumpRequest(req)
if err != nil {
fmt.Printf("could not dump request: %v\n", err)
}
proxy := httputil.NewSingleHostReverseProxy(url)
proxy.ModifyResponse = func(res *http.Response) error {
resDump, err := dumpResponse(res)
if err != nil {
return fmt.Errorf("could not dump response: %v", err)
}
mux.Lock()
captureID++
capture := Capture{
ID: captureID,
Path: req.URL.Path,
Method: req.Method,
Status: res.StatusCode,
Request: string(reqDump),
Response: string(resDump),
}
captures.Add(capture)
captures.RemoveLastAfterReaching(config.MaxCaptures)
mux.Unlock()
emitToDashboard(captures)
return nil
rec := httptest.NewRecorder()
next.ServeHTTP(rec, req)
for k, v := range rec.HeaderMap {
rw.Header()[k] = v
}
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
fmt.Printf("uh oh | %v | %s\n", err, req.URL)
rw.WriteHeader(rec.Code)
rw.Write(rec.Body.Bytes())
res := rec.Result()
resDump, err := dumpResponse(res)
if err != nil {
fmt.Printf("could not dump response: %v\n", err)
}
capture := Capture{
Path: req.URL.Path,
Method: req.Method,
Status: res.StatusCode,
Request: string(reqDump),
Response: string(resDump),
}
repo.Insert(capture)
emitToDashboard(repo.FindAll())
})
}
func proxyHandler(URL string) http.Handler {
url, _ := url.Parse(URL)
proxy := httputil.NewSingleHostReverseProxy(url)
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
fmt.Printf("uh oh | %v | %s %s\n", err, req.Method, req.URL)
}
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
req.Host = url.Host
req.URL.Host = url.Host
req.URL.Scheme = url.Scheme
proxy.ServeHTTP(rw, req)
})
}
func dumpRequest(req *http.Request) ([]byte, error) {
if req.Header.Get("Content-Encoding") == "gzip" {
var originalBody bytes.Buffer
tee := io.TeeReader(req.Body, &originalBody)
reader, _ := gzip.NewReader(tee)
var reqBody []byte
req.Body, reqBody = drain(req.Body)
reader, _ := gzip.NewReader(bytes.NewReader(reqBody))
req.Body = ioutil.NopCloser(reader)
reqDump, err := httputil.DumpRequest(req, true)
req.Body = ioutil.NopCloser(&originalBody)
req.Body = ioutil.NopCloser(bytes.NewReader(reqBody))
return reqDump, err
}
return httputil.DumpRequest(req, true)
@ -144,19 +148,30 @@ func dumpRequest(req *http.Request) ([]byte, error) {
func dumpResponse(res *http.Response) ([]byte, error) {
if res.Header.Get("Content-Encoding") == "gzip" {
var originalBody bytes.Buffer
tee := io.TeeReader(res.Body, &originalBody)
reader, _ := gzip.NewReader(tee)
var resBody []byte
res.Body, resBody = drain(res.Body)
reader, _ := gzip.NewReader(bytes.NewReader(resBody))
res.Body = ioutil.NopCloser(reader)
resDump, err := httputil.DumpResponse(res, true)
res.Body = ioutil.NopCloser(&originalBody)
res.Body = ioutil.NopCloser(bytes.NewReader(resBody))
return resDump, err
}
return httputil.DumpResponse(res, true)
}
func emitToDashboard(captures Captures) {
if dashboardSocket != nil {
dashboardSocket.Emit("captures", captures.MetadataOnly())
}
func drain(b io.ReadCloser) (io.ReadCloser, []byte) {
all, _ := ioutil.ReadAll(b)
b.Close()
return ioutil.NopCloser(bytes.NewReader(all)), all
}
func emitToDashboard(captures []Capture) {
if dashboardSocket == nil {
return
}
metadatas := make([]CaptureMetadata, len(captures))
for i, capture := range captures {
metadatas[i] = capture.Metadata()
}
dashboardSocket.Emit("captures", metadatas)
}

View file

@ -18,6 +18,7 @@ import (
// Test the reverse proxy handler
func TestProxyHandler(t *testing.T) {
// given
tt := []TestCase{
GetRequest(),
PostRequest(),
@ -25,10 +26,12 @@ func TestProxyHandler(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
service := httptest.NewServer(http.HandlerFunc(tc.service))
capture := httptest.NewServer(proxyHandler(Config{TargetURL: service.URL}))
capture := httptest.NewServer(proxyHandler(service.URL))
// when
resp := tc.request(capture.URL)
// then
tc.test(t, resp)
resp.Body.Close()
@ -88,13 +91,16 @@ func PostRequest() TestCase {
func TestDumpRequest(t *testing.T) {
msg := "hello"
// given
req, err := http.NewRequest(http.MethodPost, "http://localhost:9000/", strings.NewReader(msg))
if err != nil {
t.Errorf("Could not create request: %v", err)
}
// when
body, err := dumpRequest(req)
// then
if err != nil {
t.Errorf("Dump Request error: %v", err)
}
@ -106,14 +112,17 @@ func TestDumpRequest(t *testing.T) {
func TestDumpRequestGzip(t *testing.T) {
msg := "hello"
// given
req, err := http.NewRequest(http.MethodPost, "http://localhost:9000/", strings.NewReader(gzipStr(msg)))
req.Header.Set("Content-Encoding", "gzip")
if err != nil {
t.Errorf("Could not create request: %v", err)
}
// when
body, err := dumpRequest(req)
// then
if err != nil {
t.Errorf("Dump Request Gzip error: %v", err)
}
@ -125,10 +134,13 @@ func TestDumpRequestGzip(t *testing.T) {
func TestDumpResponse(t *testing.T) {
msg := "hello"
// given
res := &http.Response{Body: ioutil.NopCloser(strings.NewReader(msg))}
// when
body, err := dumpResponse(res)
// then
if err != nil {
t.Errorf("Dump Response Error: %v", err)
}
@ -140,14 +152,15 @@ func TestDumpResponse(t *testing.T) {
func TestDumpResponseGzip(t *testing.T) {
msg := "hello"
// make a response
// given
h := make(http.Header)
h.Set("Content-Encoding", "gzip")
res := &http.Response{Header: h, Body: ioutil.NopCloser(strings.NewReader(gzipStr(msg)))}
// dump it
// when
body, err := dumpResponse(res)
// then
if err != nil {
t.Errorf("Dump Response error: %v", err)
}
@ -160,17 +173,21 @@ func TestCaptureIDConcurrence(t *testing.T) {
// This test bothers me
// given
interactions := 1000
// Startup servers
service := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond)
rw.WriteHeader(http.StatusOK)
}))
capture := httptest.NewServer(proxyHandler(Config{TargetURL: service.URL, MaxCaptures: interactions}))
repo := NewCapturesRepository(interactions)
capture := httptest.NewServer(NewRecorder(repo, proxyHandler(service.URL)))
defer service.Close()
defer capture.Close()
// when
// Starts go routines so that captureID is incremented concurrently within proxyHandler()
wg := &sync.WaitGroup{}
wg.Add(interactions)
@ -178,14 +195,20 @@ func TestCaptureIDConcurrence(t *testing.T) {
go func() {
_, err := http.Get(capture.URL)
if err != nil {
t.Errorf("Request Failed: %v", err)
t.Fatalf("Request Failed: %v", err)
}
wg.Done()
}()
}
wg.Wait()
// then
// Tests if captures IDs are sequential
captures := repo.FindAll()
if len(captures) == 0 {
t.Fatalf("No captures found")
}
ids := make([]int, len(captures))
for i := 0; i < len(captures); i++ {
ids[i] = captures[i].ID