~emersion/public-inbox

tlstunnel: Add config reloading v2 APPLIED

minus: 1
 Add config reloading

 3 files changed, 152 insertions(+), 27 deletions(-)
#374800 .build.yml success
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~emersion/public-inbox/patches/16082/mbox | git am -3
Learn more about email & git

[PATCH tlstunnel v2] Add config reloading Export this patch

Instead of updating the configuration, we configure a new Server instance and
then migrate Listeners that still exist to it. Open client connections are
left completely untouched.

Closes https://todo.sr.ht/~emersion/tlstunnel/1
---
- Does not die on failed reload anymore.
- Handles to Frontends and Server in Listener are now exchanged
  atomically. It's ugly, but it is less error-prone than explicit locking,
  where it is easy to forget to take the lock.
- The Listener error check on closing the socket is unchanged.

 cmd/tlstunnel/main.go |  53 +++++++++++++++---
 server.go             | 124 +++++++++++++++++++++++++++++++++++-------
 tlstunnel.1.scd       |   2 +
 3 files changed, 152 insertions(+), 27 deletions(-)

diff --git a/cmd/tlstunnel/main.go b/cmd/tlstunnel/main.go
index f4ba7ef..5f04c86 100644
--- a/cmd/tlstunnel/main.go
+++ b/cmd/tlstunnel/main.go
@@ -2,7 +2,11 @@ package main

import (
	"flag"
	"fmt"
	"log"
	"os"
	"os/signal"
	"syscall"

	"git.sr.ht/~emersion/go-scfg"
	"git.sr.ht/~emersion/tlstunnel"
@@ -15,13 +19,10 @@ var (
	certDataPath = ""
)

func main() {
	flag.StringVar(&configPath, "config", configPath, "path to configuration file")
	flag.Parse()

func newServer() (*tlstunnel.Server, error) {
	cfg, err := scfg.Load(configPath)
	if err != nil {
		log.Fatalf("failed to load config file: %v", err)
		return nil, fmt.Errorf("failed to load config file: %w", err)
	}

	srv := tlstunnel.NewServer()
@@ -37,7 +38,7 @@ func main() {
	}
	logger, err := loggerCfg.Build()
	if err != nil {
		log.Fatalf("failed to initialize zap logger: %v", err)
		return nil, fmt.Errorf("failed to initialize zap logger: %w", err)
	}
	srv.ACMEConfig.Logger = logger
	srv.ACMEManager.Logger = logger
@@ -47,12 +48,48 @@ func main() {
	}

	if err := srv.Load(cfg); err != nil {
		log.Fatal(err)
		return nil, err
	}

	return srv, nil
}

// main parses the command line, starts the server built from the
// configuration file, and then services signals until asked to exit.
//
// SIGINT and SIGTERM stop the server and return. SIGHUP builds a fresh server
// from the configuration file and hands the running listeners over to it with
// Replace; if either step fails the error is logged and the previous server
// stays in use.
func main() {
	flag.StringVar(&configPath, "config", configPath, "path to configuration file")
	flag.Parse()

	srv, err := newServer()
	if err != nil {
		log.Fatalf("failed to create server: %v", err)
	}

	// Subscribe before starting the server. The channel is buffered because
	// the signal package does not block when delivering to it.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)

	if err := srv.Start(); err != nil {
		log.Fatal(err)
	}

	for sig := range sigCh {
		switch sig {
		// Go switch cases do not fall through, so both termination signals
		// have to share one case to reach the shutdown code.
		case syscall.SIGINT, syscall.SIGTERM:
			srv.Stop()
			return
		case syscall.SIGHUP:
			log.Print("caught SIGHUP, reloading config")
			newSrv, err := newServer()
			if err != nil {
				log.Printf("reload failed: %v", err)
				continue
			}
			if err := newSrv.Replace(srv); err != nil {
				log.Printf("reload failed: %v", err)
				continue
			}
			// Only switch over once the new server has fully taken over.
			srv = newSrv
			log.Print("successfully reloaded config")
		}
	}
}
diff --git a/server.go b/server.go
index c9afe32..b96eb2c 100644
--- a/server.go
+++ b/server.go
@@ -8,6 +8,7 @@ import (
	"log"
	"net"
	"strings"
	"sync/atomic"

	"git.sr.ht/~emersion/go-scfg"
	"github.com/caddyserver/certmagic"
@@ -24,6 +25,8 @@ type Server struct {

	ACMEManager *certmagic.ACMEManager
	ACMEConfig  *certmagic.Config

	cancelACME context.CancelFunc
}

func NewServer() *Server {
@@ -57,17 +60,28 @@ func (srv *Server) RegisterListener(addr string) *Listener {
	return ln
}

func (srv *Server) Start() error {
func (srv *Server) startACME() error {
	var ctx context.Context
	ctx, srv.cancelACME = context.WithCancel(context.Background())

	for _, cert := range srv.UnmanagedCerts {
		if err := srv.ACMEConfig.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
			return err
		}
	}

	if err := srv.ACMEConfig.ManageAsync(context.Background(), srv.ManagedNames); err != nil {
	if err := srv.ACMEConfig.ManageAsync(ctx, srv.ManagedNames); err != nil {
		return fmt.Errorf("failed to manage TLS certificates: %v", err)
	}

	return nil
}

func (srv *Server) Start() error {
	if err := srv.startACME(); err != nil {
		return err
	}

	for _, ln := range srv.Listeners {
		if err := ln.Start(); err != nil {
			return err
@@ -76,37 +90,94 @@ func (srv *Server) Start() error {
	return nil
}

type Listener struct {
	Address   string
// Stop shuts the server down: it invokes the stored ACME cancel function and
// then stops each registered listener.
func (srv *Server) Stop() {
	srv.cancelACME()
	// TODO: clean cached unmanaged certs
	for addr := range srv.Listeners {
		srv.Listeners[addr].Stop()
	}
}

// Replace starts the server but takes over existing listeners from an old
// Server instance.
//
// Listeners whose address is already present in old are not started. Instead
// the old listener is kept and switched to the new configuration with
// UpdateFrom. Listeners of old whose address is absent from srv are stopped.
//
// If Replace returns an error, every listener started by this call has been
// stopped again and old's listeners are untouched.
// NOTE(review): old's ACME context is cancelled before the new ACME setup
// runs, so after a startACME failure old keeps serving but its ACME context
// is already cancelled. Confirm whether that is acceptable.
func (srv *Server) Replace(old *Server) error {
	// Listeners started by this call. Only these may be stopped on failure:
	// the others were never started, and stopping them would touch a
	// listening socket that does not exist.
	var started []*Listener
	rollback := func() {
		for _, ln := range started {
			ln.Stop()
		}
	}

	// Try to start new listeners
	for addr, ln := range srv.Listeners {
		if _, ok := old.Listeners[addr]; ok {
			continue
		}
		if err := ln.Start(); err != nil {
			rollback()
			return err
		}
		started = append(started, ln)
	}

	// Restart ACME
	old.cancelACME()
	if err := srv.startACME(); err != nil {
		// Release the context created by the failed attempt, if any.
		if srv.cancelACME != nil {
			srv.cancelACME()
		}
		rollback()
		return err
	}
	// TODO: clean cached unmanaged certs

	// Take over existing listeners and terminate old ones
	for addr, oldLn := range old.Listeners {
		if ln, ok := srv.Listeners[addr]; ok {
			srv.Listeners[addr] = oldLn.UpdateFrom(ln)
		} else {
			oldLn.Stop()
		}
	}

	return nil
}

// listenerHandles groups the state of a Listener that comes from the
// configuration: the Server it belongs to and its frontends. A Listener keeps
// a pointer to one of these in an atomic.Value so both fields are swapped
// together.
type listenerHandles struct {
	Server    *Server
	Frontends map[string]*Frontend // indexed by server name
}

// Listener accepts TCP connections on one address and dispatches them to the
// frontends currently registered for it.
type Listener struct {
	// Address is the address passed to net.Listen by Start.
	Address string
	// netLn is the listening socket; it is assigned by Start and closed by
	// Stop.
	netLn   net.Listener
	// atomic holds the current *listenerHandles. It is loaded on use and
	// replaced as a whole by UpdateFrom.
	atomic  atomic.Value
}

func newListener(srv *Server, addr string) *Listener {
	return &Listener{
		Address:   addr,
	ln := &Listener{
		Address: addr,
	}
	ln.atomic.Store(&listenerHandles{
		Server:    srv,
		Frontends: make(map[string]*Frontend),
	}
	})
	return ln
}

func (ln *Listener) RegisterFrontend(name string, fe *Frontend) error {
	if _, ok := ln.Frontends[name]; ok {
	fes := ln.atomic.Load().(*listenerHandles).Frontends
	if _, ok := fes[name]; ok {
		return fmt.Errorf("listener %q: duplicate frontends for server name %q", ln.Address, name)
	}
	ln.Frontends[name] = fe
	fes[name] = fe
	return nil
}

func (ln *Listener) Start() error {
	netLn, err := net.Listen("tcp", ln.Address)
	var err error
	ln.netLn, err = net.Listen("tcp", ln.Address)
	if err != nil {
		return err
	}
	log.Printf("listening on %q", ln.Address)

	go func() {
		if err := ln.serve(netLn); err != nil {
		if err := ln.serve(); err != nil {
			log.Fatalf("listener %q: %v", ln.Address, err)
		}
	}()
@@ -114,10 +185,22 @@ func (ln *Listener) Start() error {
	return nil
}

func (ln *Listener) serve(netLn net.Listener) error {
// Stop closes the listening socket, which makes the accept loop in serve
// return. It does nothing on a listener that has no socket (never started, or
// net.Listen failed), so cleanup code may call it on any listener.
func (ln *Listener) Stop() {
	if ln.netLn == nil {
		return
	}
	ln.netLn.Close()
}

// UpdateFrom makes ln use the handles currently held by new, leaving ln's
// listening socket as it is. It returns ln so the caller can store it in
// place of new.
func (ln *Listener) UpdateFrom(new *Listener) *Listener {
	handles := new.atomic.Load()
	ln.atomic.Store(handles)
	return ln
}

func (ln *Listener) serve() error {
	for {
		conn, err := netLn.Accept()
		if err != nil {
		conn, err := ln.netLn.Accept()
		if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
			// Listening socket has been closed by Stop()
			return nil
		} else if err != nil {
			return fmt.Errorf("failed to accept connection: %v", err)
		}

@@ -131,9 +214,10 @@ func (ln *Listener) serve(netLn net.Listener) error {

func (ln *Listener) handle(conn net.Conn) error {
	defer conn.Close()
	srv := ln.atomic.Load().(*listenerHandles).Server

	// TODO: setup timeouts
	tlsConfig := ln.Server.ACMEConfig.TLSConfig()
	tlsConfig := srv.ACMEConfig.TLSConfig()
	getConfigForClient := tlsConfig.GetConfigForClient
	tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
		// Call previous GetConfigForClient function, if any
@@ -145,7 +229,7 @@ func (ln *Listener) handle(conn net.Conn) error {
				return nil, err
			}
		} else {
			tlsConfig = ln.Server.ACMEConfig.TLSConfig()
			tlsConfig = srv.ACMEConfig.TLSConfig()
		}

		fe, err := ln.matchFrontend(hello.ServerName)
@@ -171,18 +255,20 @@ func (ln *Listener) handle(conn net.Conn) error {
}

func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) {
	fe, ok := ln.Frontends[serverName]
	fes := ln.atomic.Load().(*listenerHandles).Frontends

	fe, ok := fes[serverName]
	if !ok {
		// Match wildcard certificates, allowing only a single, non-partial
		// wildcard, in the left-most label
		i := strings.IndexByte(serverName, '.')
		// Don't allow wildcards with only a TLD (e.g. *.com)
		if i >= 0 && strings.IndexByte(serverName[i+1:], '.') >= 0 {
			fe, ok = ln.Frontends["*"+serverName[i:]]
			fe, ok = fes["*"+serverName[i:]]
		}
	}
	if !ok {
		fe, ok = ln.Frontends[""]
		fe, ok = fes[""]
	}
	if !ok {
		return nil, fmt.Errorf("can't find frontend for server name %q", serverName)
diff --git a/tlstunnel.1.scd b/tlstunnel.1.scd
index 30ee269..b4c409a 100644
--- a/tlstunnel.1.scd
+++ b/tlstunnel.1.scd
@@ -27,6 +27,8 @@ The config file has one directive per line. Directives have a name, followed
by parameters separated by space characters. Directives may have children in
blocks delimited by "{" and "}". Lines beginning with "#" are comments.

tlstunnel will reload the config file when it receives the HUP signal.

Example:

```
-- 
2.29.2
Pushed, thanks for working on this!