---
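This patch adds the struct PersistentHosts, which represents a set of
known hosts that is persisted through a HostWriter. It also renames
NewHostsFile to OpenHostsFile and changes (*KnownHosts).Add to no longer
return an error; both are breaking API changes.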
client.go | 9 ++---
tofu/tofu.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++------
2 files changed, 87 insertions(+), 17 deletions(-)
diff --git a/client.go b/client.go
index ba84245..a45c659 100644
--- a/client.go
+++ b/client.go
@@ -20,8 +20,7 @@ type Client struct {
// If the returned error is not nil, the certificate will not be trusted
// and the request will be aborted.
//
- // For a basic trust on first use implementation, see (*KnownHosts).TOFU
- // in the tofu submodule.
+ // See the tofu submodule for an implementation of trust on first use.
TrustCertificate func(hostname string, cert *x509.Certificate) error
// Timeout specifies a time limit for requests made by this
@@ -88,8 +87,7 @@ func (c *Client) Do(req *Request) (*Response, error) {
if c.Timeout != 0 {
err := conn.SetDeadline(start.Add(c.Timeout))
if err != nil {
- return nil, fmt.Errorf(
- "failed to set connection deadline: %w", err)
+ return nil, fmt.Errorf("failed to set connection deadline: %w", err)
}
}
@@ -114,8 +112,7 @@ func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) {
err := req.Write(w)
if err != nil {
- return nil, fmt.Errorf(
- "failed to write request data: %w", err)
+ return nil, fmt.Errorf("failed to write request: %w", err)
}
if err := w.Flush(); err != nil {
diff --git a/tofu/tofu.go b/tofu/tofu.go
index 2ea8ac8..a928be6 100644
--- a/tofu/tofu.go
+++ b/tofu/tofu.go
@@ -27,7 +27,7 @@ type KnownHosts struct {
}
// Add adds a host to the list of known hosts.
-func (k *KnownHosts) Add(h Host) error {
+func (k *KnownHosts) Add(h Host) {
k.mu.Lock()
defer k.mu.Unlock()
if k.hosts == nil {
@@ -35,7 +35,6 @@ func (k *KnownHosts) Add(h Host) error {
}
k.hosts[h.Hostname] = h
- return nil
}
// Lookup returns the known host entry corresponding to the given hostname.
@@ -144,7 +143,7 @@ func (k *KnownHosts) Parse(r io.Reader) error {
// TOFU implements basic trust on first use.
//
// If the host is not on file, it is added to the list.
-// If the host on file is expired, it is replaced with the provided host.
+// If the host on file is expired, its entry is replaced with the provided host.
// If the fingerprint does not match the one on file, an error is returned.
func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error {
host := NewHost(hostname, cert.Raw, cert.NotAfter)
@@ -181,9 +180,9 @@ func NewHostWriter(w io.WriteCloser) *HostWriter {
}
}
-// NewHostsFile returns a new host writer that appends to the file at the given path.
+// OpenHostsFile returns a new host writer that appends to the file at the given path.
// The file is created if it does not exist.
-func NewHostsFile(path string) (*HostWriter, error) {
+func OpenHostsFile(path string) (*HostWriter, error) {
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return nil, err
@@ -212,6 +211,83 @@ func (h *HostWriter) Close() error {
return h.cl.Close()
}
+// PersistentHosts is a set of known hosts backed by a HostWriter.
+// Lookups are answered from memory; every host added through it is
+// written to the HostWriter before it is recorded in memory.
+type PersistentHosts struct {
+	hosts  *KnownHosts // in-memory view used for lookups
+	writer *HostWriter // destination that new entries are written to
+}
+
+// NewPersistentHosts pairs an in-memory set of known hosts with the writer
+// that receives every host added through the returned value. The hosts
+// value is used directly, not copied. Closing the returned PersistentHosts
+// closes writer.
+func NewPersistentHosts(hosts *KnownHosts, writer *HostWriter) *PersistentHosts {
+	return &PersistentHosts{
+		hosts:  hosts,
+		writer: writer,
+	}
+}
+
+// LoadPersistentHosts loads persistent hosts from the file at the given path.
+//
+// The file is opened for appending, and created if it does not exist,
+// before its entries are read. This lets the first use of a new path
+// succeed: it reads a freshly created file rather than failing on a
+// missing one. If loading fails, the opened file is closed before the
+// error is returned.
+//
+// Call Close on the result to close the file.
+func LoadPersistentHosts(path string) (*PersistentHosts, error) {
+	writer, err := OpenHostsFile(path)
+	if err != nil {
+		return nil, err
+	}
+
+	hosts := &KnownHosts{}
+	if err := hosts.Load(path); err != nil {
+		// The load error is the one the caller needs; a failure to close
+		// the file we just opened would add nothing actionable.
+		_ = writer.Close()
+		return nil, err
+	}
+
+	return NewPersistentHosts(hosts, writer), nil
+}
+
+// Add records h as a known host.
+//
+// The entry is first written to the underlying HostWriter. It is added to
+// the in-memory set only if that write succeeds, so a failed write leaves
+// the in-memory set unchanged and is reported as an error.
+func (p *PersistentHosts) Add(h Host) error {
+	if err := p.writer.WriteHost(h); err != nil {
+		return fmt.Errorf("failed to persist host: %w", err)
+	}
+	p.hosts.Add(h)
+	return nil
+}
+
+// Lookup reports the in-memory known host entry stored for hostname.
+// The boolean result is false when there is no entry for hostname.
+func (p *PersistentHosts) Lookup(hostname string) (host Host, ok bool) {
+	host, ok = p.hosts.Lookup(hostname)
+	return host, ok
+}
+
+// Entries lists the in-memory known host entries sorted by hostname, as
+// provided by the underlying KnownHosts.
+func (p *PersistentHosts) Entries() []Host {
+	known := p.hosts
+	return known.Entries()
+}
+
+// TOFU implements trust on first use with a persistent set of known hosts.
+//
+// If there is no entry for hostname, or the stored entry has expired, a
+// host entry for cert is persisted via Add and the certificate is trusted.
+// An expired entry is therefore superseded by a newly written one.
+// Otherwise the certificate's fingerprint must match the stored one, and
+// an error is returned if it does not.
+//
+// NOTE: the lookup and the add are separate steps, so concurrent calls for
+// the same new hostname may each persist an entry.
+func (p *PersistentHosts) TOFU(hostname string, cert *x509.Certificate) error {
+	candidate := NewHost(hostname, cert.Raw, cert.NotAfter)
+
+	stored, found := p.Lookup(hostname)
+	switch {
+	case !found, time.Now().After(stored.Expires):
+		// First contact, or the previous entry is no longer valid.
+		return p.Add(candidate)
+	case bytes.Equal(stored.Fingerprint, candidate.Fingerprint):
+		return nil
+	default:
+		return fmt.Errorf("fingerprint for %q does not match", hostname)
+	}
+}
+
+// Close closes the HostWriter that new entries are written to and returns
+// any error from doing so. The in-memory entries are left untouched.
+func (p *PersistentHosts) Close() error {
+	w := p.writer
+	return w.Close()
+}
+
// Host represents a host entry with a fingerprint using a certain algorithm.
type Host struct {
Hostname string // hostname
@@ -259,8 +335,7 @@ func (h *Host) UnmarshalText(text []byte) error {
parts := bytes.Split(text, []byte(" "))
if len(parts) != 4 {
- return fmt.Errorf(
- "expected the format %q", format)
+ return fmt.Errorf("expected the format %q", format)
}
if len(parts[0]) == 0 {
@@ -271,8 +346,7 @@ func (h *Host) UnmarshalText(text []byte) error {
algorithm := string(parts[1])
if algorithm != "SHA-512" {
- return fmt.Errorf(
- "unsupported algorithm %q", algorithm)
+ return fmt.Errorf("unsupported algorithm %q", algorithm)
}
h.Algorithm = algorithm
@@ -298,8 +372,7 @@ func (h *Host) UnmarshalText(text []byte) error {
unix, err := strconv.ParseInt(string(parts[3]), 10, 0)
if err != nil {
- return fmt.Errorf(
- "invalid unix timestamp: %w", err)
+ return fmt.Errorf("invalid unix timestamp: %w", err)
}
h.Expires = time.Unix(unix, 0)
--
2.30.0