~emersion/public-inbox

soju: Implement upstream SASL EXTERNAL support v1 NEEDS REVISION

fox.cpp: 1
 Implement upstream SASL EXTERNAL support

 3 files changed, 294 insertions(+), 6 deletions(-)
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~emersion/public-inbox/patches/10750/mbox | git am -3
Learn more about email & git
View this thread in the archives

[PATCH soju] Implement upstream SASL EXTERNAL support Export this patch

---
 db.go       |  81 ++++++++++++++++++++--
 service.go  | 189 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 upstream.go |  30 ++++++++-
 3 files changed, 294 insertions(+), 6 deletions(-)

diff --git a/db.go b/db.go
index 20ebb4e..0d9e092 100644
--- a/db.go
+++ b/db.go
@@ -2,6 +2,7 @@ package soju

import (
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"sync"
@@ -21,6 +22,12 @@ type SASL struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		CertBlob    []byte
		PrivKeyBlob []byte
	}
}

type Network struct {
@@ -68,6 +75,8 @@ CREATE TABLE Network (
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);
@@ -87,6 +96,8 @@ var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
}

type DB struct {
@@ -293,6 +304,8 @@ func (db *DB) StoreNetwork(username string, network *Network) error {
		case "PLAIN":
			saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
			saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
		case "EXTERNAL":
			// keep saslPlain* nil
		default:
			return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
		}
@@ -302,18 +315,21 @@ func (db *DB) StoreNetwork(username string, network *Network) error {
	if network.ID != 0 {
		_, err = db.db.Exec(`UPDATE Network
			SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
				sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
				sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?, sasl_external_cert = ?,
                sasl_external_key = ?
			WHERE id = ?`,
			netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
			saslMechanism, saslPlainUsername, saslPlainPassword, network.ID,
			network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob)
	} else {
		var res sql.Result
		res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
				realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
				sasl_plain_password)
			VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
				sasl_plain_password, sasl_external_cert, sasl_external_key)
			VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
			username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword)
			saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
			network.SASL.External.PrivKeyBlob)
		if err != nil {
			return err
		}
@@ -374,6 +390,61 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
	return channels, nil
}

func (db *DB) SetSASLMechanism(networkID int64, mech string) error {
	if mech == "" {
		_, err := db.db.Exec(`UPDATE Network SET sasl_mechanism = NULL WHERE id = ?`,
			networkID)
		return err
	}

	_, err := db.db.Exec(`UPDATE Network SET sasl_mechanism = ? WHERE id = ?`,
		mech, networkID)
	return err
}

var ErrNoCertificate = errors.New("no certificate associated with network")

func (db *DB) GetCert(networkID int64) (derBytes, privKey []byte, err error) {
	row := db.db.QueryRow(`SELECT sasl_external_cert, sasl_external_key FROM Network WHERE id = ?`,
		networkID)
	if err := row.Scan(&derBytes, &privKey); err != nil {
		if err == sql.ErrNoRows {
			return nil, nil, ErrNoCertificate
		}
		return nil, nil, err
	}
	if len(derBytes) == 0 {
		return nil, nil, ErrNoCertificate
	}
	return derBytes, privKey, nil
}

func (db *DB) SetCert(networkID int64, cert, privKey []byte) error {
	if cert == nil {
		res, err := db.db.Exec(`UPDATE Network SET sasl_external_cert = NULL, sasl_external_key = NULL WHERE id = ?`,
			networkID)
		if err != nil {
			return err
		}

		affected, err := res.RowsAffected()
		if err != nil {
			return nil // whatever, assume we are good.
		}

		// Be useful for the user and report if no certificate was present in
		// the first place.
		if affected == 0 {
			return errors.New("no certificate set for the network")
		}
		return nil
	}

	_, err := db.db.Exec(`UPDATE Network SET sasl_external_cert = ?, sasl_external_key = ? WHERE id = ?`,
		cert, privKey, networkID)
	return err
}

func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
	db.lock.Lock()
	defer db.lock.Unlock()
diff --git a/service.go b/service.go
index 011d1a2..5404c6a 100644
--- a/service.go
+++ b/service.go
@@ -1,10 +1,24 @@
package soju

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/hex"
	"errors"
	"flag"
	"fmt"
	"io/ioutil"
	"math/big"
	"strings"
	"time"

	"github.com/google/shlex"
	"golang.org/x/crypto/bcrypt"
@@ -117,6 +131,30 @@ func init() {
					desc:   "delete a network",
					handle: handleServiceNetworkDelete,
				},
				"sasl": {
					usage:  "<name> <mechanism>",
					desc:   "change used SASL mechanism, use 'none' to disable SASL",
					handle: handleServiceNetworkSasl,
				},
			},
		},
		"certfp": {
			children: serviceCommandSet{
				"generate": {
					usage:  "<network name>",
					desc:   "generate a new self-signed certificate with RSA-3072 keypair",
					handle: handleServiceCertfpGenerate,
				},
				"fingerprint": {
					usage:  "<network name>",
					desc:   "show fingerprints of certificate associated with the network",
					handle: handleServiceCertfpFingerprints,
				},
				"reset": {
					usage:  "<network name>",
					desc:   "disable SASL EXTERNAL authentication and remove stored certificate",
					handle: handleServiceCertfpReset,
				},
			},
		},
		"change-password": {
@@ -127,6 +165,135 @@ func init() {
	}
}

func handleServiceCertfpGenerate(dc *downstreamConn, params []string) error {
	fs := newFlagSet()
	keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)")
	bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA")

	if err := fs.Parse(params); err != nil {
		return err
	}

	if len(fs.Args()) != 1 {
		return errors.New("exactly one argument is required")
	}

	net := dc.user.getNetwork(fs.Arg(0))
	if net == nil {
		return fmt.Errorf("unknown network %q", fs.Arg(0))
	}

	var (
		privKey crypto.PrivateKey
		pubKey  crypto.PublicKey
	)
	switch *keyType {
	case "rsa":
		key, err := rsa.GenerateKey(rand.Reader, *bits)
		if err != nil {
			return err
		}
		privKey = key
		pubKey = key.Public()
	case "ecdsa":
		key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
		if err != nil {
			return err
		}
		privKey = key
		pubKey = key.Public()
	case "ed25519":
		var err error
		pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
		if err != nil {
			return err
		}
	}

	// Using PKCS#8 allows easier extension for new key types.
	privKeyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
	if err != nil {
		return err
	}

	notBefore := time.Now()
	// Lets make a fair assumption nobody will use the same cert for more than 20 years...
	notAfter := notBefore.Add(24 * time.Hour * 365 * 20)
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return err
	}
	cert := &x509.Certificate{
		SerialNumber: serialNumber,
		Subject:      pkix.Name{CommonName: "soju auto-generated certificate"},
		NotBefore:    notBefore,
		NotAfter:     notAfter,
		KeyUsage:     x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
	}
	derBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, pubKey, privKey)
	if err != nil {
		return err
	}

	err = dc.srv.db.SetCert(net.ID, derBytes, privKeyBytes)
	if err != nil {
		return err
	}

	sendServicePRIVMSG(dc, "certificate generated")

	sha1Sum := sha1.Sum(derBytes)
	sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:]))
	sha256Sum := sha256.Sum256(derBytes)
	sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:]))

	return dc.srv.db.SetSASLMechanism(net.ID, "EXTERNAL")
}

func handleServiceCertfpFingerprints(dc *downstreamConn, params []string) error {
	if len(params) != 1 {
		return fmt.Errorf("expected exactly one argument")
	}

	net := dc.user.getNetwork(params[0])
	if net == nil {
		return fmt.Errorf("unknown network %q", params[0])
	}

	derBytes, _, err := dc.srv.db.GetCert(net.ID)
	if err != nil {
		return err
	}

	sha1Sum := sha1.Sum(derBytes)
	sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:]))
	sha256Sum := sha256.Sum256(derBytes)
	sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:]))
	return nil
}

func handleServiceCertfpReset(dc *downstreamConn, params []string) error {
	if len(params) != 1 {
		return fmt.Errorf("expected exactly one argument")
	}

	net := dc.user.getNetwork(params[0])
	if net == nil {
		return fmt.Errorf("unknown network %q", params[0])
	}

	if err := dc.srv.db.SetCert(net.ID, nil, nil); err != nil {
		return err
	}

	if net.SASL.Mechanism == "EXTERNAL" {
		return dc.srv.db.SetSASLMechanism(net.ID, "")
	}
	return nil
}

func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, l *[]string) {
	for name, cmd := range cmds {
		words := append(prefix, name)
@@ -295,6 +462,28 @@ func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
	return nil
}

func handleServiceNetworkSasl(dc *downstreamConn, params []string) error {
	if len(params) != 2 {
		return fmt.Errorf("expected exactly two arguments")
	}

	net := dc.user.getNetwork(params[0])
	if net == nil {
		return fmt.Errorf("unknown network %q", params[0])
	}

	mech := strings.ToUpper(params[1])
	if mech == "NONE" {
		mech = ""
	}
	if err := dc.srv.db.SetSASLMechanism(net.ID, mech); err != nil {
		return err
	}

	sendServicePRIVMSG(dc, fmt.Sprintf("changed SASL mechanism for %q to %s", net.GetName(), params[1]))
	return nil
}

func handlePasswordChange(dc *downstreamConn, params []string) error {
	if len(params) != 1 {
		return fmt.Errorf("expected exactly one argument")
diff --git a/upstream.go b/upstream.go
index 1f281fe..e082ac4 100644
--- a/upstream.go
+++ b/upstream.go
@@ -1,7 +1,10 @@
package soju

import (
	"crypto"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"errors"
	"fmt"
@@ -100,7 +103,29 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
		}

		logger.Printf("connecting to TLS server at address %q", addr)
		netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, nil)

		var cfg *tls.Config
		if network.SASL.Mechanism == "EXTERNAL" {
			derBytes, privKey, err := network.user.srv.db.GetCert(network.ID)
			if err != nil {
				return nil, fmt.Errorf("failed to fetch certificate: %v", err)
			}
			key, err := x509.ParsePKCS8PrivateKey(privKey)
			if err != nil {
				return nil, fmt.Errorf("failed to parse private key: %v", err)
			}
			cfg = &tls.Config{
				Certificates: []tls.Certificate{
					{
						Certificate: [][]byte{derBytes},
						PrivateKey:  key.(crypto.PrivateKey),
					},
				},
			}
			logger.Printf("using TLS client certificate %x", sha256.Sum256(derBytes))
		}

		netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, cfg)
	case "irc+insecure":
		if !strings.ContainsRune(addr, ':') {
			addr = addr + ":6667"
@@ -1315,6 +1340,9 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
		case "PLAIN":
			uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
			uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
		case "EXTERNAL":
			uc.logger.Printf("starting SASL EXTERNAL authentication")
			uc.saslClient = sasl.NewExternalClient("")
		default:
			return fmt.Errorf("unsupported SASL mechanism %q", name)
		}
-- 
2.26.2
Thanks for your patch! Here are some comments below.