~emersion/soju-dev

This thread contains a patchset. You're looking at the original emails, but you may wish to use the patch review UI. Review patch
1

[PATCH v3] PostgreSQL support

Details
Message ID
<20210917173251.5555-1-hubert@hirtz.pm>
DKIM signature
pass
Download raw message
Patch: +434 -6
---

Sorry, forgot to delete the duplicate FOREIGN KEY line for Network in v2

 cmd/soju/main.go    |   2 +-
 cmd/sojuctl/main.go |   2 +-
 db.go               |  12 ++
 db_postgres.go      | 398 ++++++++++++++++++++++++++++++++++++++++++++
 db_sqlite.go        |   4 +-
 doc/soju.1.scd      |  19 ++-
 go.mod              |   1 +
 go.sum              |   2 +
 8 files changed, 434 insertions(+), 6 deletions(-)
 create mode 100644 db_postgres.go

diff --git a/cmd/soju/main.go b/cmd/soju/main.go
index 9f67182..879dd41 100644
--- a/cmd/soju/main.go
+++ b/cmd/soju/main.go
@@ -61,7 +61,7 @@ func main() {
		cfg.Listen = []string{":6697"}
	}

	db, err := soju.OpenSqliteDB(cfg.SQLDriver, cfg.SQLSource)
	db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
	if err != nil {
		log.Fatalf("failed to open database: %v", err)
	}
diff --git a/cmd/sojuctl/main.go b/cmd/sojuctl/main.go
index 48720b2..d19ccfb 100644
--- a/cmd/sojuctl/main.go
+++ b/cmd/sojuctl/main.go
@@ -43,7 +43,7 @@ func main() {
		cfg = config.Defaults()
	}

	db, err := soju.OpenSqliteDB(cfg.SQLDriver, cfg.SQLSource)
	db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
	if err != nil {
		log.Fatalf("failed to open database: %v", err)
	}
diff --git a/db.go b/db.go
index b28827a..d6fae28 100644
--- a/db.go
+++ b/db.go
@@ -26,6 +26,18 @@ type Database interface {
	StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error
}

// OpenDB opens the database described by source with the given driver.
//
// Supported drivers are "sqlite3" (source is the SQLite file path) and
// "postgres" (source is a lib/pq connection string). Any other driver name
// yields an error.
func OpenDB(driver, source string) (Database, error) {
	switch driver {
	case "sqlite3":
		return OpenSqliteDB(source)
	case "postgres":
		return OpenPostgresDB(source)
	default:
		return nil, fmt.Errorf("unsupported database driver: %q", driver)
	}
}

type User struct {
	ID       int64
	Username string
diff --git a/db_postgres.go b/db_postgres.go
new file mode 100644
index 0000000..d975c88
--- /dev/null
+++ b/db_postgres.go
@@ -0,0 +1,398 @@
package soju

import (
	"database/sql"
	"fmt"
	"math"
	"strings"
	"time"

	_ "github.com/lib/pq"
)

// postgresFunctions installs the SQL helper functions used by upgrade. It is
// executed at every start, inside the upgrade transaction.
//
// sojuVersion() returns the schema version stored in Config, or 0 when the
// Config table does not exist yet (fresh database).
//
// CREATE OR REPLACE is used rather than DROP FUNCTION + CREATE FUNCTION. It
// has the same effect while the signature stays the same. It also avoids
// "DROP FUNCTION name" without an argument list, which PostgreSQL releases
// before 10 reject.
const postgresFunctions = `
CREATE OR REPLACE FUNCTION sojuVersion() RETURNS INTEGER AS $$
DECLARE
	version INTEGER;
BEGIN
	SELECT Config.version INTO version FROM Config;
	RETURN version;
EXCEPTION
	WHEN UNDEFINED_TABLE THEN RETURN 0;
END;
$$ LANGUAGE plpgsql
`

// postgresSchema is the complete, latest database schema. upgrade executes it
// as a single multi-statement Exec when sojuVersion() reports version 0,
// i.e. when the Config table does not exist yet.
//
// "User" and the "user" column are double-quoted because USER is a reserved
// word in PostgreSQL.
//
// NOTE(review): unlike SQLite, PostgreSQL enforces VARCHAR(n) lengths, so
// over-long values (e.g. many connect_commands) will make writes fail.
// Confirm these limits are adequate, or consider TEXT.
// NOTE(review): Network.name is nullable, and NULLs never collide in a UNIQUE
// constraint, so UNIQUE("user", name) does not prevent several unnamed
// networks. Confirm that is intended.
const postgresSchema = `
CREATE TABLE Config (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);

CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TABLE Network (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA DEFAULT NULL,
	sasl_external_key BYTEA DEFAULT NULL,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE Channel (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES Network(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE DeliveryReceipt (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES Network(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);
`

// postgresMigrations lists the schema migrations. The current schema version
// is len(postgresMigrations).
//
// When a database is at version v (v > 0), upgrade executes entries v through
// len-1 in order. Entry 0 is never executed: version 0 means a fresh
// database, which gets postgresSchema installed directly. New migrations must
// therefore be appended here, and postgresSchema kept in sync with them.
var postgresMigrations = []string{
	"", // migration #0 is reserved for schema initialization
}

// PostgresDB is the Database implementation backed by PostgreSQL. It is
// accessed through database/sql using the "postgres" driver from lib/pq.
type PostgresDB struct {
	db *sql.DB
}

// OpenPostgresDB opens the PostgreSQL database described by source (a lib/pq
// connection string) and brings its schema up to date. If the schema upgrade
// fails, the handle is closed and the upgrade error is returned.
func OpenPostgresDB(source string) (Database, error) {
	sqlDB, err := sql.Open("postgres", source)
	if err != nil {
		return nil, err
	}

	pg := &PostgresDB{db: sqlDB}
	if err = pg.upgrade(); err != nil {
		sqlDB.Close()
		return nil, err
	}
	return pg, nil
}

// Close closes the underlying *sql.DB handle and returns its error.
func (db *PostgresDB) Close() error {
	handle := db.db
	return handle.Close()
}

// upgrade brings the database schema up to date inside a single transaction.
//
// A fresh database (version 0) gets postgresSchema installed directly. An
// existing one has its pending postgresMigrations entries applied in order.
// In both cases the version stored in Config is then set to
// len(postgresMigrations). If the database is already up to date, nothing is
// committed. If its schema is newer than this binary knows, an error is
// returned.
//
// Errors are wrapped with %w so callers can still inspect the underlying
// driver error (e.g. via errors.As); the rendered messages are unchanged.
func (db *PostgresDB) upgrade() error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback just returns sql.ErrTxDone, so this
	// only discards work on the early-return paths below.
	defer tx.Rollback()

	if _, err := tx.Exec(postgresFunctions); err != nil {
		return fmt.Errorf("failed to install functions: %w", err)
	}

	var version int
	if err := tx.QueryRow("SELECT sojuVersion()").Scan(&version); err != nil {
		return fmt.Errorf("failed to query schema version: %w", err)
	}

	latest := len(postgresMigrations)
	if version == latest {
		return nil
	}
	if version > latest {
		return fmt.Errorf("soju (version %d) older than schema (version %d)", latest, version)
	}

	if version == 0 {
		// Fresh database: install the latest schema directly instead of
		// replaying every migration.
		if _, err := tx.Exec(postgresSchema); err != nil {
			return fmt.Errorf("failed to initialize schema: %w", err)
		}
	} else {
		for i := version; i < latest; i++ {
			if _, err := tx.Exec(postgresMigrations[i]); err != nil {
				return fmt.Errorf("failed to execute migration #%v: %w", i, err)
			}
		}
	}

	_, err = tx.Exec(`INSERT INTO Config(id, version) VALUES (1, $1)
		ON CONFLICT (id) DO UPDATE SET version = $1`, latest)
	if err != nil {
		return fmt.Errorf("failed to bump schema version: %w", err)
	}

	return tx.Commit()
}

// ListUsers returns every user stored in the database.
//
// realname is selected as well (as GetUser does). Otherwise users loaded
// through this function would carry an empty Realname, and a later
// StoreUser would overwrite the stored value with NULL. NULL password and
// realname columns map to empty strings.
func (db *PostgresDB) ListUsers() ([]User, error) {
	rows, err := db.db.Query(`SELECT id, username, password, admin, realname FROM "User"`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var users []User
	for rows.Next() {
		var user User
		var password, realname sql.NullString
		if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
			return nil, err
		}
		user.Password = password.String
		user.Realname = realname.String
		users = append(users, user)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return users, nil
}

// GetUser fetches the user with the given username. When no row matches, the
// error from Scan (sql.ErrNoRows) is returned unchanged. NULL password and
// realname columns map to empty strings.
func (db *PostgresDB) GetUser(username string) (*User, error) {
	const query = `SELECT id, password, admin, realname FROM "User" WHERE username = $1`

	u := User{Username: username}
	var nullPassword, nullRealname sql.NullString
	err := db.db.QueryRow(query, username).Scan(&u.ID, &nullPassword, &u.Admin, &nullRealname)
	if err != nil {
		return nil, err
	}
	u.Password = nullPassword.String
	u.Realname = nullRealname.String
	return &u, nil
}

// StoreUser creates or updates a user.
//
// When user.ID is zero, a new row is inserted and user.ID is set to the
// generated id. Inserting a username that already exists then fails on the
// UNIQUE constraint, instead of silently overwriting that user's password
// and admin flag as an upsert on username would.
//
// When user.ID is non-zero, the row with that id is updated in place.
// The username column is left untouched.
func (db *PostgresDB) StoreUser(user *User) error {
	password := toNullString(user.Password)
	realname := toNullString(user.Realname)

	if user.ID == 0 {
		return db.db.QueryRow(`
			INSERT INTO "User" (username, password, admin, realname)
			VALUES ($1, $2, $3, $4)
			RETURNING id`,
			user.Username, password, user.Admin, realname).Scan(&user.ID)
	}

	_, err := db.db.Exec(`
		UPDATE "User"
		SET password = $2, admin = $3, realname = $4
		WHERE id = $1`,
		user.ID, password, user.Admin, realname)
	return err
}

// DeleteUser removes the user with the given id. Per postgresSchema, the
// user's networks are removed by ON DELETE CASCADE, and with them their
// channels and delivery receipts.
func (db *PostgresDB) DeleteUser(id int64) error {
	const stmt = `DELETE FROM "User" WHERE id = $1`
	if _, err := db.db.Exec(stmt, id); err != nil {
		return err
	}
	return nil
}

// ListNetworks returns every network owned by the user with the given id.
//
// NULL text columns map to empty strings. A non-NULL connect_commands value
// is split on CRLF back into individual commands; a NULL one leaves
// ConnectCommands nil.
func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
	rows, err := db.db.Query(`
		SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
		       sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
		FROM Network
		WHERE "user" = $1`, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []Network
	for rows.Next() {
		var n Network
		var nullName, nullUsername, nullRealname, nullPass, nullCmds sql.NullString
		var nullMech, nullPlainUser, nullPlainPass sql.NullString
		if err := rows.Scan(
			&n.ID, &nullName, &n.Addr, &n.Nick, &nullUsername, &nullRealname,
			&nullPass, &nullCmds, &nullMech, &nullPlainUser, &nullPlainPass,
			&n.SASL.External.CertBlob, &n.SASL.External.PrivKeyBlob, &n.Enabled,
		); err != nil {
			return nil, err
		}

		n.Name = nullName.String
		n.Username = nullUsername.String
		n.Realname = nullRealname.String
		n.Pass = nullPass.String
		if nullCmds.Valid {
			n.ConnectCommands = strings.Split(nullCmds.String, "\r\n")
		}
		n.SASL.Mechanism = nullMech.String
		n.SASL.Plain.Username = nullPlainUser.String
		n.SASL.Plain.Password = nullPlainPass.String

		result = append(result, n)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}

// StoreNetwork creates or updates a network owned by the user with the
// given id.
//
// When network.ID is zero, a new row is inserted and network.ID is set to
// the generated id. Otherwise the row with that id is updated in place.
//
// Updating by id, rather than upserting on the ("user", name) UNIQUE
// constraint, matters for two cases:
//   - Renaming a network must modify the existing row instead of inserting a
//     second one.
//   - Unnamed networks have a NULL name, which never triggers a UNIQUE
//     conflict.
//
// For SASL PLAIN, any EXTERNAL certificate/key is cleared, both in the
// database and in network itself. For EXTERNAL, the PLAIN credentials are
// stored as NULL. Any other non-empty mechanism is rejected.
//
// Connect commands are joined with CRLF into a single column.
func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
	netName := toNullString(network.Name)
	netUsername := toNullString(network.Username)
	realname := toNullString(network.Realname)
	pass := toNullString(network.Pass)
	connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))

	var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
	if network.SASL.Mechanism != "" {
		saslMechanism = toNullString(network.SASL.Mechanism)
		switch network.SASL.Mechanism {
		case "PLAIN":
			saslPlainUsername = toNullString(network.SASL.Plain.Username)
			saslPlainPassword = toNullString(network.SASL.Plain.Password)
			network.SASL.External.CertBlob = nil
			network.SASL.External.PrivKeyBlob = nil
		case "EXTERNAL":
			// keep saslPlain* NULL
		default:
			return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
		}
	}

	if network.ID == 0 {
		return db.db.QueryRow(`
			INSERT INTO Network("user", name, addr, nick, username, realname, pass, connect_commands,
			                    sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
			                    sasl_external_key, enabled)
			VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
			RETURNING id`,
			userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
			network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
	}

	_, err := db.db.Exec(`
		UPDATE Network
		SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
		    connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
		    sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
		    enabled = $14
		WHERE id = $1`,
		network.ID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
		saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
		network.SASL.External.PrivKeyBlob, network.Enabled)
	return err
}

// DeleteNetwork removes the network with the given id. Per postgresSchema,
// its channels and delivery receipts are removed by ON DELETE CASCADE.
func (db *PostgresDB) DeleteNetwork(id int64) error {
	const stmt = `DELETE FROM Network WHERE id = $1`
	if _, err := db.db.Exec(stmt, id); err != nil {
		return err
	}
	return nil
}

// ListChannels returns every channel saved for the network with the given id.
//
// detach_after is stored as a whole number of seconds and is converted back
// to a time.Duration. NULL key and detached_internal_msgid map to empty
// strings.
func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
	rows, err := db.db.Query(`
		SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
		       detach_on
		FROM Channel
		WHERE network = $1`, networkID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var list []Channel
	for rows.Next() {
		var (
			c         Channel
			nullKey   sql.NullString
			nullMsgID sql.NullString
			afterSecs int64
		)
		err := rows.Scan(&c.ID, &c.Name, &nullKey, &c.Detached, &nullMsgID,
			&c.RelayDetached, &c.ReattachOn, &afterSecs, &c.DetachOn)
		if err != nil {
			return nil, err
		}
		c.Key = nullKey.String
		c.DetachedInternalMsgID = nullMsgID.String
		c.DetachAfter = time.Second * time.Duration(afterSecs)
		list = append(list, c)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return list, nil
}

// StoreChannel upserts a channel of the given network, keyed on the
// (network, name) UNIQUE constraint, and sets ch.ID to the affected row's id.
//
// ch.DetachAfter is persisted as whole seconds, rounded up.
func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
	args := []interface{}{
		networkID,
		ch.Name,
		toNullString(ch.Key),
		ch.Detached,
		toNullString(ch.DetachedInternalMsgID),
		ch.RelayDetached,
		ch.ReattachOn,
		int64(math.Ceil(ch.DetachAfter.Seconds())),
		ch.DetachOn,
	}
	row := db.db.QueryRow(`
		INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
		                    detach_after, detach_on)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
		ON CONFLICT (network, name)
		DO UPDATE SET network = $1, name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
		              relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
		RETURNING id`, args...)
	return row.Scan(&ch.ID)
}

// DeleteChannel removes the channel row with the given id.
func (db *PostgresDB) DeleteChannel(id int64) error {
	const stmt = `DELETE FROM Channel WHERE id = $1`
	if _, err := db.db.Exec(stmt, id); err != nil {
		return err
	}
	return nil
}

// ListDeliveryReceipts returns every delivery receipt stored for the network
// with the given id, across all clients.
func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
	rows, err := db.db.Query(`
		SELECT id, target, client, internal_msgid
		FROM DeliveryReceipt
		WHERE network = $1`, networkID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []DeliveryReceipt
	for rows.Next() {
		out = append(out, DeliveryReceipt{})
		r := &out[len(out)-1]
		if err := rows.Scan(&r.ID, &r.Target, &r.Client, &r.InternalMsgID); err != nil {
			return nil, err
		}
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}

// StoreClientDeliveryReceipts replaces the delivery receipts saved for client
// on the given network with receipts. Each element's ID is set to its row id.
//
// Existing receipts for this (network, client) pair are deleted first, so
// receipts for targets no longer in the slice do not linger. The delete and
// the inserts run in one transaction: if any step fails, the previously
// stored set is left untouched rather than half-updated.
func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback is a no-op returning sql.ErrTxDone.
	defer tx.Rollback()

	_, err = tx.Exec(`DELETE FROM DeliveryReceipt WHERE network = $1 AND client = $2`,
		networkID, client)
	if err != nil {
		return err
	}

	// ON CONFLICT is kept so that duplicate targets within receipts do not
	// abort the transaction; the last one wins.
	stmt, err := tx.Prepare(`
		INSERT INTO DeliveryReceipt(network, target, client, internal_msgid)
		VALUES ($1, $2, $3, $4)
		ON CONFLICT (network, target, client)
		DO UPDATE SET internal_msgid = $4
		RETURNING id`)
	if err != nil {
		return err
	}
	defer stmt.Close()

	for i := range receipts {
		rcpt := &receipts[i]
		if err := stmt.QueryRow(networkID, rcpt.Target, client, rcpt.InternalMsgID).Scan(&rcpt.ID); err != nil {
			return err
		}
	}

	return tx.Commit()
}
diff --git a/db_sqlite.go b/db_sqlite.go
index 7c0840a..1ff872e 100644
--- a/db_sqlite.go
+++ b/db_sqlite.go
@@ -142,8 +142,8 @@ type SqliteDB struct {
	db   *sql.DB
}

func OpenSqliteDB(driver, source string) (Database, error) {
	sqlSqliteDB, err := sql.Open(driver, source)
func OpenSqliteDB(source string) (Database, error) {
	sqlSqliteDB, err := sql.Open("sqlite3", source)
	if err != nil {
		return nil, err
	}
diff --git a/doc/soju.1.scd b/doc/soju.1.scd
index 6838bbc..c3779f0 100644
--- a/doc/soju.1.scd
+++ b/doc/soju.1.scd
@@ -106,8 +106,23 @@ The following directives are supported:
*tls* <cert> <key>
	Enable TLS support. The certificate and the key files must be PEM-encoded.

*db* sqlite3 <path>
	Set the SQLite database path (default: "soju.db" in the current directory).
*db* <driver> <path>
	Set the database location for user, network and channel storage.  By
	default, a _sqlite3_ database is opened in "./soju.db".

	Supported drivers:

	- _sqlite3_ expects <path> to point to the SQLite file
	- _postgres_ expects <path> to be a space-separated list of _key=value_
	  parameters, e.g.

		db postgres "user=soju dbname=soju host=/run/postgres sslmode=disable"

		Please note that _sslmode_ defaults to _require_.

		See the documentation of your version of PostgreSQL for the full list of
		allowed parameters.  The current version can be found here:
		<https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>.

*log* fs <path>
	Path to the bouncer logs root directory, or empty to disable logging. By
diff --git a/go.mod b/go.mod
index 3c3072e..6634d67 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
	git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9
	github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21
	github.com/klauspost/compress v1.13.5 // indirect
	github.com/lib/pq v1.10.3
	github.com/mattn/go-sqlite3 v1.14.8
	github.com/pires/go-proxyproto v0.6.1
	golang.org/x/crypto v0.0.0-20210817164053-32db794688a5
diff --git a/go.sum b/go.sum
index dfce806..0497d01 100644
--- a/go.sum
+++ b/go.sum
@@ -43,6 +43,8 @@ github.com/klauspost/compress v1.13.5 h1:9O69jUPDcsT9fEm74W92rZL9FQY7rCdaXVneq+y
github.com/klauspost/compress v1.13.5/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg=
github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU=
-- 
2.33.0
Details
Message ID
<fGk_7g7azHRTTv-i9iX0gTpv1Jg5E454vEJ78MVgCT9baFDSTmZU51UCClhBLA1MXQoXDSH1QOrlVNfIrLjCfXgNyt-PRAaRxVvNv4PCkmo=@emersion.fr>
In-Reply-To
<20210917173251.5555-1-hubert@hirtz.pm> (view parent)
DKIM signature
pass
Download raw message
This version looks pretty good to me! Just one or two comments below.

Bonus points for using sql.Named, but we can do that as a second step
anyways.

On Friday, September 17th, 2021 at 19:32, Hubert Hirtz <hubert@hirtz.pm> wrote:

> diff --git a/db.go b/db.go
> index b28827a..d6fae28 100644
> --- a/db.go
> +++ b/db.go
> @@ -26,6 +26,18 @@ type Database interface {
>  	StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error
>  }
>
> +func OpenDB(driver, source string) (Database, error) {
> +	switch driver {
> +	case "sqlite3":
> +		return OpenSqliteDB(source)
> +	case "postgres":
> +		return OpenPostgresDB(source)
> +	default:
> +		return nil, fmt.Errorf("unsupported database driver: %q", driver)
> +	}
> +

Style nit: stray newline

> +}
> +
>  type User struct {
>  	ID       int64
>  	Username string
> diff --git a/db_postgres.go b/db_postgres.go
> new file mode 100644
> index 0000000..d975c88
> --- /dev/null
> +++ b/db_postgres.go
> @@ -0,0 +1,398 @@
> +package soju
> +
> +import (
> +	"database/sql"
> +	"fmt"
> +	"math"
> +	"strings"
> +	"time"
> +
> +	_ "github.com/lib/pq"
> +)
> +
> +const postgresFunctions = `
> +DROP FUNCTION IF EXISTS sojuVersion;
> +CREATE FUNCTION sojuVersion() RETURNS INTEGER AS $$
> +DECLARE
> +	version INTEGER;
> +BEGIN
> +	SELECT Config.version INTO version FROM Config;
> +	RETURN version;
> +EXCEPTION
> +	WHEN UNDEFINED_TABLE THEN RETURN 0;
> +END;
> +$$ LANGUAGE plpgsql
> +`

Hm, instead of this, I think we should be able to just try to do the SELECT,
then on error check if it's a *pq.Error [1], then check if Error.Code is set
to "42P01" (undefined_table) [2].

Another alternative would be to use a custom per-database configuration
parameter as described in [3], but a table sounds a bit cleaner.

[1]: https://pkg.go.dev/github.com/lib/pq#Error
[2]: https://www.postgresql.org/docs/13/errcodes-appendix.html
[3]: https://stackoverflow.com/questions/34476062/setting-a-configuration-parameter-for-functions-implemented-in-pl-pgsql

> +const postgresSchema = `
> +CREATE TABLE Config (
> +	id SMALLINT PRIMARY KEY,

Note to self: seems like this is used to ensure the table contains a single
row. Could use a BEFORE INSERT trigger instead, but not sure it'd be better, I'm fine
with the current approach.
Reply to thread Export thread (mbox)