change reverse proxy core
This commit is contained in:
parent
56a16fe526
commit
0d73abcefd
138
main.go
138
main.go
|
@ -16,30 +16,21 @@ import (
|
||||||
"github.com/googollee/go-socket.io"
|
"github.com/googollee/go-socket.io"
|
||||||
)
|
)
|
||||||
|
|
||||||
type transport struct {
|
|
||||||
http.RoundTripper
|
|
||||||
maxItems int
|
|
||||||
currItemID int
|
|
||||||
}
|
|
||||||
|
|
||||||
var captures Captures
|
var captures Captures
|
||||||
|
|
||||||
var dashboardSocket socketio.Socket
|
var dashboardSocket socketio.Socket
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
config := ReadConfig()
|
config := ReadConfig()
|
||||||
|
startCapture(config)
|
||||||
|
}
|
||||||
|
|
||||||
transp := &transport{
|
func startCapture(config Config) {
|
||||||
RoundTripper: http.DefaultTransport,
|
http.Handle("/", proxyHandler(config))
|
||||||
maxItems: config.MaxCaptures,
|
http.Handle("/socket.io/", dashboardSocketHandler(config))
|
||||||
currItemID: 0,
|
http.Handle(config.DashboardPath, dashboardHandler())
|
||||||
}
|
http.Handle(config.DashboardClearPath, dashboardClearHandler())
|
||||||
|
http.Handle(config.DashboardItemInfoPath, dashboardItemInfoHandler())
|
||||||
http.Handle("/", getProxyHandler(config.TargetURL, transp))
|
|
||||||
http.Handle("/socket.io/", getDashboardSocketHandler(config))
|
|
||||||
http.Handle(config.DashboardPath, getDashboardHandler())
|
|
||||||
http.Handle(config.DashboardClearPath, getDashboardClearHandler())
|
|
||||||
http.Handle(config.DashboardItemInfoPath, getDashboardItemInfoHandler())
|
|
||||||
|
|
||||||
proxyHost := fmt.Sprintf("http://localhost:%s", config.ProxyPort)
|
proxyHost := fmt.Sprintf("http://localhost:%s", config.ProxyPort)
|
||||||
|
|
||||||
|
@ -49,7 +40,7 @@ func main() {
|
||||||
fmt.Println(http.ListenAndServe(":"+config.ProxyPort, nil))
|
fmt.Println(http.ListenAndServe(":"+config.ProxyPort, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDashboardSocketHandler(config Config) http.Handler {
|
func dashboardSocketHandler(config Config) http.Handler {
|
||||||
server, err := socketio.NewServer(nil)
|
server, err := socketio.NewServer(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("socket server error", err)
|
fmt.Println("socket server error", err)
|
||||||
|
@ -65,7 +56,7 @@ func getDashboardSocketHandler(config Config) http.Handler {
|
||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDashboardClearHandler() http.Handler {
|
func dashboardClearHandler() http.Handler {
|
||||||
return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||||
captures = nil
|
captures = nil
|
||||||
emitToDashboard(captures)
|
emitToDashboard(captures)
|
||||||
|
@ -73,14 +64,14 @@ func getDashboardClearHandler() http.Handler {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDashboardHandler() http.Handler {
|
func dashboardHandler() http.Handler {
|
||||||
return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||||
res.Header().Add("Content-Type", "text/html")
|
res.Header().Add("Content-Type", "text/html")
|
||||||
res.Write([]byte(dashboardHTML))
|
res.Write([]byte(dashboardHTML))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDashboardItemInfoHandler() http.Handler {
|
func dashboardItemInfoHandler() http.Handler {
|
||||||
return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||||
idStr := req.URL.Path[strings.LastIndex(req.URL.Path, "/")+1:]
|
idStr := req.URL.Path[strings.LastIndex(req.URL.Path, "/")+1:]
|
||||||
idInt, _ := strconv.Atoi(idStr)
|
idInt, _ := strconv.Atoi(idStr)
|
||||||
|
@ -94,67 +85,70 @@ func getDashboardItemInfoHandler() http.Handler {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProxyHandler(targetURL string, transp *transport) http.Handler {
|
func proxyHandler(config Config) http.Handler {
|
||||||
url, _ := url.Parse(targetURL)
|
url, _ := url.Parse(config.TargetURL)
|
||||||
proxy := httputil.NewSingleHostReverseProxy(url)
|
captureID := 0
|
||||||
proxy.Transport = transp
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
req.Host = url.Host
|
||||||
request.Host = request.URL.Host
|
req.URL.Host = url.Host
|
||||||
proxy.ServeHTTP(response, request)
|
req.URL.Scheme = url.Scheme
|
||||||
|
|
||||||
|
reqDump, err := dumpRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Could not dump request: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := httputil.NewSingleHostReverseProxy(url)
|
||||||
|
proxy.ModifyResponse = func(res *http.Response) error {
|
||||||
|
resDump, err := dumpResponse(res)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not dump response: %v", err)
|
||||||
|
}
|
||||||
|
captureID++
|
||||||
|
capture := Capture{
|
||||||
|
ID: captureID,
|
||||||
|
Path: req.URL.Path,
|
||||||
|
Method: req.Method,
|
||||||
|
Status: res.StatusCode,
|
||||||
|
Request: string(reqDump),
|
||||||
|
Response: string(resDump),
|
||||||
|
}
|
||||||
|
captures.Add(capture)
|
||||||
|
captures.RemoveLastAfterReaching(config.MaxCaptures)
|
||||||
|
emitToDashboard(captures)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
fmt.Printf("uh oh | %v | %s\n", err, req.URL)
|
||||||
|
}
|
||||||
|
proxy.ServeHTTP(rw, req)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
|
|
||||||
reqDump, err := dumpRequest(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := t.RoundTripper.RoundTrip(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("uh oh | %v | %s", err, req.URL)
|
|
||||||
}
|
|
||||||
|
|
||||||
resDump, err := dumpResponse(res)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
capture := Capture{
|
|
||||||
ID: t.NewItemID(),
|
|
||||||
Path: req.URL.Path,
|
|
||||||
Method: req.Method,
|
|
||||||
Status: res.StatusCode,
|
|
||||||
Request: string(reqDump),
|
|
||||||
Response: string(resDump),
|
|
||||||
}
|
|
||||||
|
|
||||||
captures.Add(capture)
|
|
||||||
captures.RemoveLastAfterReaching(t.maxItems)
|
|
||||||
emitToDashboard(captures)
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *transport) NewItemID() int {
|
|
||||||
t.currItemID++
|
|
||||||
return t.currItemID
|
|
||||||
}
|
|
||||||
|
|
||||||
func dumpRequest(req *http.Request) ([]byte, error) {
|
func dumpRequest(req *http.Request) ([]byte, error) {
|
||||||
|
if req.Header.Get("Content-Encoding") == "gzip" {
|
||||||
|
var originalBody bytes.Buffer
|
||||||
|
tee := io.TeeReader(req.Body, &originalBody)
|
||||||
|
reader, _ := gzip.NewReader(tee)
|
||||||
|
req.Body = ioutil.NopCloser(reader)
|
||||||
|
reqDump, err := httputil.DumpRequest(req, true)
|
||||||
|
req.Body = ioutil.NopCloser(&originalBody)
|
||||||
|
return reqDump, err
|
||||||
|
}
|
||||||
return httputil.DumpRequest(req, true)
|
return httputil.DumpRequest(req, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dumpResponse(res *http.Response) ([]byte, error) {
|
func dumpResponse(res *http.Response) ([]byte, error) {
|
||||||
var originalBody bytes.Buffer
|
|
||||||
reader := io.TeeReader(res.Body, &originalBody)
|
|
||||||
if res.Header.Get("Content-Encoding") == "gzip" {
|
if res.Header.Get("Content-Encoding") == "gzip" {
|
||||||
reader, _ = gzip.NewReader(reader)
|
var originalBody bytes.Buffer
|
||||||
|
tee := io.TeeReader(res.Body, &originalBody)
|
||||||
|
reader, _ := gzip.NewReader(tee)
|
||||||
|
res.Body = ioutil.NopCloser(reader)
|
||||||
|
resDump, err := httputil.DumpResponse(res, true)
|
||||||
|
res.Body = ioutil.NopCloser(&originalBody)
|
||||||
|
return resDump, err
|
||||||
}
|
}
|
||||||
res.Body = ioutil.NopCloser(reader)
|
return httputil.DumpResponse(res, true)
|
||||||
resDump, err := httputil.DumpResponse(res, true)
|
|
||||||
res.Body = ioutil.NopCloser(&originalBody)
|
|
||||||
return resDump, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func emitToDashboard(captures Captures) {
|
func emitToDashboard(captures Captures) {
|
||||||
|
|
161
main_test.go
Normal file
161
main_test.go
Normal file
|
@ -0,0 +1,161 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test the reverse proxy handler
|
||||||
|
func TestProxyHandler(t *testing.T) {
|
||||||
|
tt := []TestCase{
|
||||||
|
GetRequest(),
|
||||||
|
PostRequest(),
|
||||||
|
}
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
service := httptest.NewServer(http.HandlerFunc(tc.service))
|
||||||
|
capture := httptest.NewServer(proxyHandler(Config{TargetURL: service.URL}))
|
||||||
|
|
||||||
|
resp := tc.request(capture.URL)
|
||||||
|
|
||||||
|
tc.test(t, resp)
|
||||||
|
|
||||||
|
resp.Body.Close()
|
||||||
|
capture.Close()
|
||||||
|
service.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestCase struct {
|
||||||
|
name string
|
||||||
|
request func(string) *http.Response
|
||||||
|
service func(http.ResponseWriter, *http.Request)
|
||||||
|
test func(*testing.T, *http.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetRequest() TestCase {
|
||||||
|
msg := "hello"
|
||||||
|
return TestCase{
|
||||||
|
name: "GetRequest",
|
||||||
|
request: func(url string) *http.Response {
|
||||||
|
res, _ := http.Get(url)
|
||||||
|
return res
|
||||||
|
},
|
||||||
|
service: func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
fmt.Fprint(rw, string(msg))
|
||||||
|
},
|
||||||
|
test: func(t *testing.T, res *http.Response) {
|
||||||
|
body, _ := ioutil.ReadAll(res.Body)
|
||||||
|
if string(body) != msg {
|
||||||
|
t.Error("Wrong Body Response")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func PostRequest() TestCase {
|
||||||
|
msg := "hello"
|
||||||
|
return TestCase{
|
||||||
|
name: "PostRequest",
|
||||||
|
request: func(url string) *http.Response {
|
||||||
|
res, _ := http.Post(url, "text/plain", strings.NewReader(msg))
|
||||||
|
return res
|
||||||
|
},
|
||||||
|
service: func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
io.Copy(rw, req.Body)
|
||||||
|
},
|
||||||
|
test: func(t *testing.T, res *http.Response) {
|
||||||
|
body, _ := ioutil.ReadAll(res.Body)
|
||||||
|
if string(body) != msg {
|
||||||
|
t.Error("Wrong Body Response")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDumpRequest(t *testing.T) {
|
||||||
|
msg := "hello"
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "http://localhost:9000/", strings.NewReader(msg))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := dumpRequest(req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Dump Request error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(body), msg) {
|
||||||
|
t.Errorf("Dump Request is not '%s'", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDumpRequestGzip(t *testing.T) {
|
||||||
|
msg := "hello"
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "http://localhost:9000/", strings.NewReader(gzipStr(msg)))
|
||||||
|
req.Header.Set("Content-Encoding", "gzip")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := dumpRequest(req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Dump Request Gzip error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(body), msg) {
|
||||||
|
t.Errorf("Dump Request Gzip is not '%s'", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDumpResponse(t *testing.T) {
|
||||||
|
msg := "hello"
|
||||||
|
|
||||||
|
res := &http.Response{Body: ioutil.NopCloser(strings.NewReader(msg))}
|
||||||
|
|
||||||
|
body, err := dumpResponse(res)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Dump Response Error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(body), msg) {
|
||||||
|
t.Errorf("Dump Response is not '%s'", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDumpResponseGzip(t *testing.T) {
|
||||||
|
msg := "hello"
|
||||||
|
|
||||||
|
// make a response
|
||||||
|
h := make(http.Header)
|
||||||
|
h.Set("Content-Encoding", "gzip")
|
||||||
|
res := &http.Response{Header: h, Body: ioutil.NopCloser(strings.NewReader(gzipStr(msg)))}
|
||||||
|
|
||||||
|
// dump it
|
||||||
|
body, err := dumpResponse(res)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Dump Response error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(body), msg) {
|
||||||
|
t.Error("Not hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func gzipStr(str string) string {
|
||||||
|
var buff bytes.Buffer
|
||||||
|
g := gzip.NewWriter(&buff)
|
||||||
|
io.WriteString(g, str)
|
||||||
|
g.Close()
|
||||||
|
return buff.String()
|
||||||
|
}
|
Loading…
Reference in a new issue