~adnano/go-gemini-devel

go-gemini: tofu: Implement PersistentHosts v2 APPLIED

Adnan Maolood: 1
 tofu: Implement PersistentHosts

 2 files changed, 87 insertions(+), 17 deletions(-)
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~adnano/go-gemini-devel/patches/19791/mbox | git am -3
Learn more about email & git

[PATCH go-gemini v2] tofu: Implement PersistentHosts Export this patch

---
This patch adds the struct PersistentHosts which represents a persistent
set of known hosts.

 client.go    |  9 ++---
 tofu/tofu.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 87 insertions(+), 17 deletions(-)

diff --git a/client.go b/client.go
index ba84245..a45c659 100644
--- a/client.go
+++ b/client.go
@@ -20,8 +20,7 @@ type Client struct {
	// If the returned error is not nil, the certificate will not be trusted
	// and the request will be aborted.
	//
	// For a basic trust on first use implementation, see (*KnownHosts).TOFU
	// in the tofu submodule.
	// See the tofu submodule for an implementation of trust on first use.
	TrustCertificate func(hostname string, cert *x509.Certificate) error

	// Timeout specifies a time limit for requests made by this
@@ -88,8 +87,7 @@ func (c *Client) Do(req *Request) (*Response, error) {
	if c.Timeout != 0 {
		err := conn.SetDeadline(start.Add(c.Timeout))
		if err != nil {
			return nil, fmt.Errorf(
				"failed to set connection deadline: %w", err)
			return nil, fmt.Errorf("failed to set connection deadline: %w", err)
		}
	}

@@ -114,8 +112,7 @@ func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) {

	err := req.Write(w)
	if err != nil {
		return nil, fmt.Errorf(
			"failed to write request data: %w", err)
		return nil, fmt.Errorf("failed to write request: %w", err)
	}

	if err := w.Flush(); err != nil {
diff --git a/tofu/tofu.go b/tofu/tofu.go
index 2ea8ac8..a928be6 100644
--- a/tofu/tofu.go
+++ b/tofu/tofu.go
@@ -27,7 +27,7 @@ type KnownHosts struct {
}

// Add adds a host to the list of known hosts.
func (k *KnownHosts) Add(h Host) error {
func (k *KnownHosts) Add(h Host) {
	k.mu.Lock()
	defer k.mu.Unlock()
	if k.hosts == nil {
@@ -35,7 +35,6 @@ func (k *KnownHosts) Add(h Host) error {
	}

	k.hosts[h.Hostname] = h
	return nil
}

// Lookup returns the known host entry corresponding to the given hostname.
@@ -144,7 +143,7 @@ func (k *KnownHosts) Parse(r io.Reader) error {
// TOFU implements basic trust on first use.
//
// If the host is not on file, it is added to the list.
// If the host on file is expired, it is replaced with the provided host.
// If the host on file is expired, a new entry is added to the list.
// If the fingerprint does not match the one on file, an error is returned.
func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error {
	host := NewHost(hostname, cert.Raw, cert.NotAfter)
@@ -181,9 +180,9 @@ func NewHostWriter(w io.WriteCloser) *HostWriter {
	}
}

// NewHostsFile returns a new host writer that appends to the file at the given path.
// OpenHostsFile returns a new host writer that appends to the file at the given path.
// The file is created if it does not exist.
func NewHostsFile(path string) (*HostWriter, error) {
func OpenHostsFile(path string) (*HostWriter, error) {
	f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return nil, err
@@ -212,6 +211,83 @@ func (h *HostWriter) Close() error {
	return h.cl.Close()
}

// PersistentHosts represents a persistent set of known hosts.
type PersistentHosts struct {
	hosts  *KnownHosts
	writer *HostWriter
}

// NewPersistentHosts returns a new persistent set of known hosts.
func NewPersistentHosts(hosts *KnownHosts, writer *HostWriter) *PersistentHosts {
	return &PersistentHosts{
		hosts,
		writer,
	}
}

// LoadPersistentHosts loads persistent hosts from the file at the given path.
func LoadPersistentHosts(path string) (*PersistentHosts, error) {
	hosts := &KnownHosts{}
	if err := hosts.Load(path); err != nil {
		return nil, err
	}
	writer, err := OpenHostsFile(path)
	if err != nil {
		return nil, err
	}
	return &PersistentHosts{
		hosts,
		writer,
	}, nil
}

// Add adds a host to the list of known hosts.
// It returns an error if the host could not be persisted.
func (p *PersistentHosts) Add(h Host) error {
	err := p.writer.WriteHost(h)
	if err != nil {
		return fmt.Errorf("failed to persist host: %w", err)
	}
	p.hosts.Add(h)
	return nil
}

// Lookup returns the known host entry corresponding to the given hostname.
func (p *PersistentHosts) Lookup(hostname string) (Host, bool) {
	return p.hosts.Lookup(hostname)
}

// Entries returns the known host entries sorted by hostname.
func (p *PersistentHosts) Entries() []Host {
	return p.hosts.Entries()
}

// TOFU implements trust on first use with a persistent set of known hosts.
//
// If the host is not on file, it is added to the list.
// If the host on file is expired, a new entry is added to the list.
// If the fingerprint does not match the one on file, an error is returned.
func (p *PersistentHosts) TOFU(hostname string, cert *x509.Certificate) error {
	host := NewHost(hostname, cert.Raw, cert.NotAfter)

	knownHost, ok := p.Lookup(hostname)
	if !ok || time.Now().After(knownHost.Expires) {
		return p.Add(host)
	}

	// Check fingerprint
	if !bytes.Equal(knownHost.Fingerprint, host.Fingerprint) {
		return fmt.Errorf("fingerprint for %q does not match", hostname)
	}

	return nil
}

// Close closes the underlying HostWriter.
func (p *PersistentHosts) Close() error {
	return p.writer.Close()
}

// Host represents a host entry with a fingerprint using a certain algorithm.
type Host struct {
	Hostname    string      // hostname
@@ -259,8 +335,7 @@ func (h *Host) UnmarshalText(text []byte) error {

	parts := bytes.Split(text, []byte(" "))
	if len(parts) != 4 {
		return fmt.Errorf(
			"expected the format %q", format)
		return fmt.Errorf("expected the format %q", format)
	}

	if len(parts[0]) == 0 {
@@ -271,8 +346,7 @@ func (h *Host) UnmarshalText(text []byte) error {

	algorithm := string(parts[1])
	if algorithm != "SHA-512" {
		return fmt.Errorf(
			"unsupported algorithm %q", algorithm)
		return fmt.Errorf("unsupported algorithm %q", algorithm)
	}

	h.Algorithm = algorithm
@@ -298,8 +372,7 @@ func (h *Host) UnmarshalText(text []byte) error {

	unix, err := strconv.ParseInt(string(parts[3]), 10, 0)
	if err != nil {
		return fmt.Errorf(
			"invalid unix timestamp: %w", err)
		return fmt.Errorf("invalid unix timestamp: %w", err)
	}

	h.Expires = time.Unix(unix, 0)
-- 
2.30.0