[v3] Fix and optimise assetserver (#4049)

* Fix and optimize content type sniffer

- Minimize copying and buffering
- Ensure it sniffs the full 512-bytes prefix

* Fix assorted warnings

* Cleanup error formatting

- Remove unnecessary formatting calls
- Fix invalid format strings
- Standardise logging calls

* Fix and optimize index fallback method

- Pass through non-404 responses correctly
- Do not buffer original response

* Test content sniffing and index fallback

* Update changelog

* Remove obsolete check

* Add safety checks in sniffer
This commit is contained in:
Fabio Massaioli 2025-02-08 14:02:54 +01:00 committed by GitHub
commit d4096868e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 465 additions and 82 deletions

View file

@ -60,6 +60,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Mac + Mac JS events now fixed by [@leaanthony](https://github.com/leaanthony)
- Fixed event deadlock for macOS by [@leaanthony](https://github.com/leaanthony)
- Fixed a `Parameter incorrect` error in Window initialisation on Windows when HTML provided but no JS by [@leaanthony](https://github.com/leaanthony)
- Fixed size of response prefix used for content type sniffing in asset server by [@fbbdev](https://github.com/fbbdev) in [#4049](https://github.com/wailsapp/wails/pull/4049)
- Fixed handling of non-404 responses on root index path in asset server by [@fbbdev](https://github.com/fbbdev) in [#4049](https://github.com/wailsapp/wails/pull/4049)
### Changed

View file

@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"io/fs"
iofs "io/fs"
"net/http"
"os"
@ -24,7 +23,7 @@ type assetFileServer struct {
err error
}
func newAssetFileServerFS(vfs fs.FS) http.Handler {
func newAssetFileServerFS(vfs iofs.FS) http.Handler {
subDir, err := findPathToFile(vfs, indexHTML)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
@ -34,7 +33,7 @@ func newAssetFileServerFS(vfs fs.FS) http.Handler {
msg += fmt.Sprintf(", please make sure the embedded directory '%s' is correct and contains your assets", rootFolder)
}
err = fmt.Errorf(msg)
err = errors.New(msg)
}
} else {
vfs, err = iofs.Sub(vfs, path.Clean(subDir))

View file

@ -4,7 +4,6 @@ import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"time"
@ -13,7 +12,6 @@ import (
const (
webViewRequestHeaderWindowId = "x-wails-window-id"
webViewRequestHeaderWindowName = "x-wails-window-name"
servicePrefix = "wails/services"
HeaderAcceptLanguage = "accept-language"
)
@ -59,6 +57,11 @@ func NewAssetServer(options *Options) (*AssetServer, error) {
func (a *AssetServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
start := time.Now()
wrapped := &contentTypeSniffer{rw: rw}
defer func() {
if _, err := wrapped.complete(); err != nil {
a.options.Logger.Error("Error writing response data.", "uri", req.RequestURI, "error", err)
}
}()
req = req.WithContext(contextWithLogger(req.Context(), a.options.Logger))
a.handler.ServeHTTP(wrapped, req)
@ -90,32 +93,25 @@ func (a *AssetServer) serveHTTP(rw http.ResponseWriter, req *http.Request, userH
reqPath := req.URL.Path
switch reqPath {
case "", "/", "/index.html":
recorder := httptest.NewRecorder()
userHandler.ServeHTTP(recorder, req)
for k, v := range recorder.Result().Header {
header[k] = v
// Cache the accept-language header
// before passing the request down the chain.
acceptLanguage := req.Header.Get(HeaderAcceptLanguage)
if acceptLanguage == "" {
acceptLanguage = "en"
}
switch recorder.Code {
case http.StatusOK:
a.writeBlob(rw, indexHTML, recorder.Body.Bytes())
case http.StatusNotFound:
// Read the accept-language header
acceptLanguage := req.Header.Get(HeaderAcceptLanguage)
if acceptLanguage == "" {
acceptLanguage = "en"
}
// Set content type for default index.html
header.Set(HeaderContentType, "text/html; charset=utf-8")
a.writeBlob(rw, indexHTML, defaultIndexHTML(acceptLanguage))
default:
rw.WriteHeader(recorder.Code)
wrapped := &fallbackResponseWriter{
rw: rw,
req: req,
fallback: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
// Set content type for default index.html
header.Set(HeaderContentType, "text/html; charset=utf-8")
a.writeBlob(rw, indexHTML, defaultIndexHTML(acceptLanguage))
}),
}
userHandler.ServeHTTP(wrapped, req)
default:
// Check if the path matches the keys in the services map
for route, handler := range a.services {
if strings.HasPrefix(reqPath, route) {
@ -125,14 +121,8 @@ func (a *AssetServer) serveHTTP(rw http.ResponseWriter, req *http.Request, userH
}
}
// Check if it can be served by the user-provided handler
if !strings.HasPrefix(reqPath, servicePrefix) {
userHandler.ServeHTTP(rw, req)
return
}
rw.WriteHeader(http.StatusNotFound)
return
// Forward to the user-provided handler
userHandler.ServeHTTP(rw, req)
}
}
@ -146,13 +136,13 @@ func (a *AssetServer) AttachServiceHandler(prefix string, handler http.Handler)
func (a *AssetServer) writeBlob(rw http.ResponseWriter, filename string, blob []byte) {
err := ServeFile(rw, filename, blob)
if err != nil {
a.serveError(rw, err, "Unable to write content %s", filename)
a.serveError(rw, err, "Error writing file content.", "filename", filename)
}
}
func (a *AssetServer) serveError(rw http.ResponseWriter, err error, msg string, args ...interface{}) {
args = append(args, err)
a.options.Logger.Error(msg+":", args...)
args = append(args, "error", err)
a.options.Logger.Error(msg, args...)
rw.WriteHeader(http.StatusInternalServerError)
}
@ -163,7 +153,7 @@ func GetStartURL(userURL string) (string, error) {
// Parse the port
parsedURL, err := url.Parse(devServerURL)
if err != nil {
return "", fmt.Errorf("Error parsing environment variable 'FRONTEND_DEVSERVER_URL`: " + err.Error() + ". Please check your `Taskfile.yml` file")
return "", fmt.Errorf("error parsing environment variable `FRONTEND_DEVSERVER_URL`: %w. Please check your `Taskfile.yml` file", err)
}
port := parsedURL.Port()
if port != "" {
@ -175,7 +165,7 @@ func GetStartURL(userURL string) (string, error) {
if userURL != "" {
parsedURL, err := baseURL.Parse(userURL)
if err != nil {
return "", fmt.Errorf("Error parsing URL: " + err.Error())
return "", fmt.Errorf("error parsing URL: %w", err)
}
startURL = parsedURL.String()

View file

@ -4,7 +4,6 @@ package assetserver
import (
"embed"
_ "embed"
"io"
iofs "io/fs"
)

View file

@ -0,0 +1,244 @@
package assetserver
import (
"fmt"
"log/slog"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
_ "unsafe"
"github.com/google/go-cmp/cmp"
)
func TestContentSniffing(t *testing.T) {
longLead := strings.Repeat(" ", 512-6)
tests := map[string]struct {
Expect string
Status int
Header map[string][]string
Body []string
}{
"/simple": {
Expect: "text/html; charset=utf-8",
Body: []string{"<html><body>Hello!</body></html>"},
},
"/split": {
Expect: "text/html; charset=utf-8",
Body: []string{
"<html><body>Hello!",
"</body></html>",
},
},
"/lead/short/simple": {
Expect: "text/html; charset=utf-8",
Body: []string{
" " + "<html><body>Hello!</body></html>",
},
},
"/lead/short/split": {
Expect: "text/html; charset=utf-8",
Body: []string{
" ",
"<html><body>Hello!</body></html>",
},
},
"/lead/long/simple": {
Expect: "text/html; charset=utf-8",
Body: []string{
longLead + "<html><body>Hello!</body></html>",
},
},
"/lead/long/split": {
Expect: "text/html; charset=utf-8",
Body: []string{
longLead,
"<html><body>Hello!</body></html>",
},
},
"/lead/toolong/simple": {
Expect: "text/plain; charset=utf-8",
Body: []string{
"Hello" + longLead + "<html><body>Hello!</body></html>",
},
},
"/lead/toolong/split": {
Expect: "text/plain; charset=utf-8",
Body: []string{
"Hello" + longLead,
"<html><body>Hello!</body></html>",
},
},
"/header": {
Expect: "text/html; charset=utf-8",
Status: http.StatusForbidden,
Header: map[string][]string{
"X-Custom": {"CustomValue"},
},
Body: []string{"<html><body>Hello!</body></html>"},
},
"/custom": {
Expect: "text/plain;charset=utf-8",
Header: map[string][]string{
"Content-Type": {"text/plain;charset=utf-8"},
},
Body: []string{"<html><body>Hello!</body></html>"},
},
}
srv, err := NewAssetServer(&Options{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
test, ok := tests[r.URL.Path]
if !ok {
w.WriteHeader(http.StatusNotFound)
return
}
for key, values := range test.Header {
for _, value := range values {
w.Header().Add(key, value)
}
}
if test.Status != 0 {
w.WriteHeader(test.Status)
}
for _, chunk := range test.Body {
w.Write([]byte(chunk))
}
}),
Logger: slog.Default(),
})
if err != nil {
t.Fatal("AssetServer initialisation failed: ", err)
}
for path, test := range tests {
t.Run(path[1:], func(t *testing.T) {
res := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, path, nil)
if err != nil {
t.Fatal("http.NewRequest failed: ", err)
}
srv.ServeHTTP(res, req)
expectedStatus := http.StatusOK
if test.Status != 0 {
expectedStatus = test.Status
}
if res.Code != expectedStatus {
t.Errorf("Status code mismatch: want %d, got %d", expectedStatus, res.Code)
}
if ct := res.Header().Get("Content-Type"); ct != test.Expect {
t.Errorf("Content type mismatch: want '%s', got '%s'", test.Expect, ct)
}
for key, values := range test.Header {
if diff := cmp.Diff(values, res.Header().Values(key)); diff != "" {
t.Errorf("Header '%s' mismatch (-want +got):\n%s", key, diff)
}
}
if diff := cmp.Diff(strings.Join(test.Body, ""), res.Body.String()); diff != "" {
t.Errorf("Response body mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestIndexFallback(t *testing.T) {
// Paths to try and whether a 404 should trigger a fallback.
paths := map[string]bool{
"": true,
"/": true,
"/index": false,
"/index.html": true,
"/other": false,
}
statuses := []int{
http.StatusOK,
http.StatusNotFound,
http.StatusForbidden,
}
header := map[string][]string{
"X-Custom": {"CustomValue"},
}
body := "<html><body>Hello!</body></html>"
srv, err := NewAssetServer(&Options{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for key, values := range header {
for _, value := range values {
w.Header().Add(key, value)
}
}
status, err := strconv.Atoi(r.URL.Query().Get("status"))
if err == nil && status != 0 && status != http.StatusOK {
w.WriteHeader(status)
}
w.Write([]byte(body))
}),
Logger: slog.Default(),
})
if err != nil {
t.Fatal("AssetServer initialisation failed: ", err)
}
for path, fallback := range paths {
for _, status := range statuses {
key := "<empty path>"
if len(path) > 0 {
key = path[1:]
}
t.Run(fmt.Sprintf("%s/status=%d", key, status), func(t *testing.T) {
res := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s?status=%d", path, status), nil)
if err != nil {
t.Fatal("http.NewRequest failed: ", err)
}
srv.ServeHTTP(res, req)
fallbackTriggered := false
if status == http.StatusNotFound && fallback {
status = http.StatusOK
fallbackTriggered = true
}
if res.Code != status {
t.Errorf("Status code mismatch: want %d, got %d", status, res.Code)
}
if fallbackTriggered {
if cmp.Equal(body, res.Body.String()) {
t.Errorf("Fallback response has the same body as not found response")
}
return
} else {
for key, values := range header {
if diff := cmp.Diff(values, res.Header().Values(key)); diff != "" {
t.Errorf("Header '%s' mismatch (-want +got):\n%s", key, diff)
}
}
if diff := cmp.Diff(body, res.Body.String()); diff != "" {
t.Errorf("Response body mismatch (-want +got):\n%s", diff)
}
}
})
}
}
}

View file

@ -70,17 +70,21 @@ func (a *AssetServer) processWebViewRequestInternal(r webview.Request) {
wrw := r.Response()
defer func() {
if err := wrw.Finish(); err != nil {
a.options.Logger.Error("Error finishing request '%s': %s", uri, err)
a.options.Logger.Error("Error finishing request.", "uri", uri, "error", err)
}
}()
var rw http.ResponseWriter = &contentTypeSniffer{rw: wrw} // Make sure we have a Content-Type sniffer
defer rw.WriteHeader(http.StatusNotImplemented) // This is a NOP when a handler has already written and set the status
rw := &contentTypeSniffer{rw: wrw} // Make sure we have a Content-Type sniffer
defer func() {
if _, err := rw.complete(); err != nil {
a.options.Logger.Error("Error writing response data.", "uri", uri, "error", err)
}
}()
defer rw.WriteHeader(http.StatusNotImplemented) // This is a NOP when a handler has already written and set the status
uri, err = r.URL()
if err != nil {
a.options.Logger.Error(fmt.Sprintf("Error processing request, unable to get URL: %s (HttpResponse=500)", err))
http.Error(rw, err.Error(), http.StatusInternalServerError)
a.webviewRequestErrorHandler(uri, rw, fmt.Errorf("URL: %w", err))
return
}
@ -162,7 +166,7 @@ func (a *AssetServer) webviewRequestErrorHandler(uri string, rw http.ResponseWri
logInfo = strings.Replace(logInfo, fmt.Sprintf("%s://%s", uri.Scheme, uri.Host), "", 1)
}
a.options.Logger.Error("Error processing request (HttpResponse=500)", "details", logInfo, "error", err.Error())
a.options.Logger.Error("Error processing request (HttpResponse=500)", "details", logInfo, "error", err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}

View file

@ -22,9 +22,9 @@ const (
WailsUserAgentValue = "wails.io"
)
var (
assetServerLogger = struct{}{}
)
type assetServerLogger struct{}
var assetServerLoggerKey assetServerLogger
func ServeFile(rw http.ResponseWriter, filename string, blob []byte) error {
header := rw.Header()
@ -45,17 +45,17 @@ func isWebSocket(req *http.Request) bool {
}
func contextWithLogger(ctx context.Context, logger *slog.Logger) context.Context {
return context.WithValue(ctx, assetServerLogger, logger)
return context.WithValue(ctx, assetServerLoggerKey, logger)
}
func logInfo(ctx context.Context, message string, args ...interface{}) {
if logger, _ := ctx.Value(assetServerLogger).(*slog.Logger); logger != nil {
if logger, _ := ctx.Value(assetServerLoggerKey).(*slog.Logger); logger != nil {
logger.Info(message, args...)
}
}
func logError(ctx context.Context, message string, args ...interface{}) {
if logger, _ := ctx.Value(assetServerLogger).(*slog.Logger); logger != nil {
if logger, _ := ctx.Value(assetServerLoggerKey).(*slog.Logger); logger != nil {
logger.Error(message, args...)
}
}

View file

@ -5,38 +5,106 @@ import (
)
type contentTypeSniffer struct {
rw http.ResponseWriter
status int
wroteHeader bool
rw http.ResponseWriter
prefix []byte
status int
headerCommitted bool
headerWritten bool
}
func (rw contentTypeSniffer) Header() http.Header {
// Unwrap returns the wrapped [http.ResponseWriter] for use with [http.ResponseController].
func (rw *contentTypeSniffer) Unwrap() http.ResponseWriter {
return rw.rw
}
func (rw *contentTypeSniffer) Header() http.Header {
return rw.rw.Header()
}
func (rw *contentTypeSniffer) Write(buf []byte) (int, error) {
rw.writeHeader(buf)
return rw.rw.Write(buf)
func (rw *contentTypeSniffer) Write(chunk []byte) (int, error) {
if !rw.headerCommitted {
rw.WriteHeader(http.StatusOK)
}
if rw.headerWritten {
return rw.rw.Write(chunk)
}
if len(chunk) == 0 {
return 0, nil
}
// Cut away at most 512 bytes from chunk, and not less than 0.
cut := max(min(len(chunk), 512-len(rw.prefix)), 0)
if cut >= 512 {
// Avoid copying data if a full prefix is available on first non-zero write.
cut = len(chunk)
rw.prefix = chunk
chunk = nil
} else if cut > 0 {
// First write had less than 512 bytes -- copy data to the prefix buffer.
if rw.prefix == nil {
// Preallocate space for the prefix to be used for sniffing.
rw.prefix = make([]byte, 0, 512)
}
rw.prefix = append(rw.prefix, chunk[:cut]...)
chunk = chunk[cut:]
}
if len(rw.prefix) < 512 {
return cut, nil
}
if _, err := rw.complete(); err != nil {
return cut, err
}
n, err := rw.rw.Write(chunk)
return cut + n, err
}
func (rw *contentTypeSniffer) WriteHeader(code int) {
if rw.wroteHeader {
if rw.headerCommitted {
return
}
rw.status = code
rw.rw.WriteHeader(code)
rw.wroteHeader = true
rw.headerCommitted = true
if _, hasType := rw.Header()[HeaderContentType]; hasType {
rw.rw.WriteHeader(rw.status)
rw.headerWritten = true
}
}
func (rw *contentTypeSniffer) writeHeader(b []byte) {
if rw.wroteHeader {
// sniff sniffs the content type from the stored prefix if necessary,
// then writes the header.
func (rw *contentTypeSniffer) sniff() {
if rw.headerWritten || !rw.headerCommitted {
return
}
m := rw.rw.Header()
m := rw.Header()
if _, hasType := m[HeaderContentType]; !hasType {
m.Set(HeaderContentType, http.DetectContentType(b))
m.Set(HeaderContentType, http.DetectContentType(rw.prefix))
}
rw.WriteHeader(http.StatusOK)
rw.rw.WriteHeader(rw.status)
rw.headerWritten = true
}
// complete sniffs the content type if necessary, writes the header
// and sends the data prefix that has been stored for sniffing.
//
// Whoever creates a contentTypeSniffer instance
// is responsible for calling complete after the nested handler has returned.
func (rw *contentTypeSniffer) complete() (n int, err error) {
rw.sniff()
if rw.headerWritten && len(rw.prefix) > 0 {
n, err = rw.rw.Write(rw.prefix)
rw.prefix = nil
}
return
}

View file

@ -0,0 +1,73 @@
package assetserver
import (
"maps"
"net/http"
)
// fallbackResponseWriter wraps a [http.ResponseWriter].
// If the main handler returns status code 404,
// its response is discarded
// and the request is forwarded to the fallback handler.
type fallbackResponseWriter struct {
rw http.ResponseWriter
req *http.Request
fallback http.Handler
header http.Header
headerWritten bool
complete bool
}
// Unwrap returns the wrapped [http.ResponseWriter] for use with [http.ResponseController].
func (fw *fallbackResponseWriter) Unwrap() http.ResponseWriter {
return fw.rw
}
func (fw *fallbackResponseWriter) Header() http.Header {
if fw.header == nil {
// Preserve original header in case we get a 404 response.
fw.header = fw.rw.Header().Clone()
}
return fw.header
}
func (fw *fallbackResponseWriter) Write(chunk []byte) (int, error) {
if fw.complete {
// Fallback triggered, discard further writes.
return len(chunk), nil
}
if !fw.headerWritten {
fw.WriteHeader(http.StatusOK)
}
return fw.rw.Write(chunk)
}
func (fw *fallbackResponseWriter) WriteHeader(statusCode int) {
if fw.headerWritten {
return
}
fw.headerWritten = true
if statusCode == http.StatusNotFound {
// Protect fallback header from external modifications.
if fw.header == nil {
fw.header = fw.rw.Header().Clone()
}
// Invoke fallback handler.
fw.complete = true
fw.fallback.ServeHTTP(fw.rw, fw.req)
return
}
if fw.header != nil {
// Apply headers and forward original map to the main handler.
maps.Copy(fw.rw.Header(), fw.header)
fw.header = fw.rw.Header()
}
fw.rw.WriteHeader(statusCode)
}

View file

@ -12,7 +12,7 @@ import (
// findEmbedRootPath finds the root path in the embed FS. It's the directory which contains all the files.
func findEmbedRootPath(fileSystem embed.FS) (string, error) {
stopErr := fmt.Errorf("files or multiple dirs found")
stopErr := errors.New("files or multiple dirs found")
fPath := ""
err := fs.WalkDir(fileSystem, ".", func(path string, d fs.DirEntry, err error) error {

View file

@ -1,7 +1,7 @@
package assetserver
import (
"fmt"
"errors"
"log/slog"
"net/http"
)
@ -31,7 +31,7 @@ type Options struct {
// Validate the options
func (o Options) Validate() error {
if o.Handler == nil && o.Middleware == nil {
return fmt.Errorf("AssetServer options invalid: either Handler or Middleware must be set")
return errors.New("AssetServer options invalid: either Handler or Middleware must be set")
}
return nil

View file

@ -110,6 +110,7 @@ import "C"
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@ -227,9 +228,9 @@ func (r *requestBodyStreamReader) Read(p []byte) (n int, err error) {
case 0:
return 0, io.EOF
case -1:
return 0, fmt.Errorf("body: stream error")
return 0, errors.New("body: stream error")
case -2:
return 0, fmt.Errorf("body: no stream defined")
return 0, errors.New("body: no stream defined")
case -3:
return 0, io.ErrClosedPipe
default:

View file

@ -3,10 +3,10 @@
package webview
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/wailsapp/go-webview2/pkg/edge"
)
@ -202,15 +202,17 @@ func getHeaders(req *edge.ICoreWebView2WebResourceRequest) (http.Header, error)
}
func combineErrs(errs []error) error {
// TODO use Go1.20 errors.Join
if len(errs) == 0 {
return nil
err := errors.Join(errs...)
if err != nil {
// errors.Join wraps even a single error.
// Check the filtered error list,
// and if it has just one element return it directly.
errs = err.(interface{ Unwrap() []error }).Unwrap()
if len(errs) == 1 {
return errs[0]
}
}
errStrings := make([]string, len(errs))
for i, err := range errs {
errStrings[i] = err.Error()
}
return fmt.Errorf(strings.Join(errStrings, "\n"))
return err
}

View file

@ -4,6 +4,7 @@ package webview
import (
"bytes"
"errors"
"fmt"
"net/http"
"strings"
@ -68,7 +69,7 @@ func (rw *responseWriter) Finish() error {
if code == http.StatusNotModified {
// WebView2 has problems when a request returns a 304 status code and the WebView2 is going to hang for other
// requests including IPC calls.
errs = append(errs, fmt.Errorf("AssetServer returned 304 - StatusNotModified which are going to hang WebView2, changed code to 505 - StatusInternalServerError"))
errs = append(errs, errors.New("AssetServer returned 304 - StatusNotModified which are going to hang WebView2, changed code to 505 - StatusInternalServerError"))
code = http.StatusInternalServerError
}