---
Adds CORS handling to the file upload handler, including the missing
Access-Control-Expose-Headers header, so that browser clients can read the
'Location' response header.
cmd/soju/main.go | 7 +++--
fileupload/fileupload.go | 66 ++++++++++++++++++++++++++++++++++++++--
2 files changed, 67 insertions(+), 6 deletions(-)
diff --git a/cmd/soju/main.go b/cmd/soju/main.go
index 7dfed44..04de78b 100644
--- a/cmd/soju/main.go
+++ b/cmd/soju/main.go
@@ -159,9 +159,10 @@ func main() {
fileUploadHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cfg := srv.Config()
h := fileupload.Handler{
- Uploader: cfg.FileUploader,
- DB: db,
- Auth: cfg.Auth,
+ Uploader: cfg.FileUploader,
+ DB: db,
+ Auth: cfg.Auth,
+ HTTPOrigins: cfg.HTTPOrigins,
}
h.ServeHTTP(w, r)
})
diff --git a/fileupload/fileupload.go b/fileupload/fileupload.go
index cceb4a5..85db0cb 100644
--- a/fileupload/fileupload.go
+++ b/fileupload/fileupload.go
@@ -3,11 +3,14 @@ package fileupload
import (
"crypto/rand"
"encoding/hex"
+ "errors"
"fmt"
"io"
"mime"
"net/http"
+ "net/url"
"path"
+ "path/filepath"
"strings"
"time"
@@ -53,14 +56,71 @@ func New(driver, source string) (Uploader, error) {
}
type Handler struct {
- Uploader Uploader
- Auth auth.Authenticator
- DB database.Database
+ Uploader Uploader
+ Auth auth.Authenticator
+ DB database.Database
+ HTTPOrigins []string // CORS allow-list read by handleCors: "*", exact origins or filepath.Match patterns
+}
+
+// handleCors sets the CORS response headers for a file upload request.
+//
+// When HTTPOrigins is non-empty and the request carries an Origin header,
+// that origin must either name this server's own host or match an entry of
+// HTTPOrigins: "*", a case-insensitive exact match, or a filepath.Match
+// pattern. A non-nil error means the request was rejected and an error
+// response has already been written to resp.
+func (h *Handler) handleCors(resp http.ResponseWriter, req *http.Request) error {
+	hdr := resp.Header()
+	hdr.Set("Access-Control-Allow-Credentials", "true")
+	hdr.Set("Access-Control-Allow-Headers", "authorization,content-type")
+	// Let browser scripts read the Location header of the upload response.
+	hdr.Set("Access-Control-Expose-Headers", "Location")
+
+	reqOrigin := req.Header.Get("Origin")
+	if len(h.HTTPOrigins) == 0 || reqOrigin == "" {
+		return nil
+	}
+	// The outcome below depends on Origin, so caches must key on it too.
+	hdr.Add("Vary", "Origin")
+
+	u, err := url.Parse(reqOrigin)
+	if err != nil {
+		http.Error(resp, "Invalid Origin header", http.StatusBadRequest)
+		return fmt.Errorf("invalid Origin header: %w", err)
+	}
+
+	// Same-host requests are always allowed; otherwise try each entry.
+	allowed := u.Host == req.Host
+	for _, pattern := range h.HTTPOrigins {
+		if allowed {
+			break
+		}
+		if pattern == "*" || strings.EqualFold(pattern, reqOrigin) {
+			allowed = true
+		} else if allowed, err = filepath.Match(pattern, reqOrigin); err != nil {
+			http.Error(resp, "Internal error", http.StatusInternalServerError)
+			return fmt.Errorf("invalid origin pattern %q: %w", pattern, err)
+		}
+	}
+	if !allowed {
+		// 403 rather than 401: no credentials can make this origin acceptable.
+		http.Error(resp, "Forbidden Origin", http.StatusForbidden)
+		return errors.New("unauthorized origin")
+	}
+
+	// Echo the request's Origin verbatim: browsers compare it byte-for-byte,
+	// so reflecting the configured spelling could break case-insensitive matches.
+	hdr.Set("Access-Control-Allow-Origin", reqOrigin)
+	return nil
}
func (h *Handler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
resp.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none';")
+ if err := h.handleCors(resp, req); err != nil {
+ return
+ }
+
if h.Uploader == nil {
http.NotFound(resp, req)
return
--
2.43.0