gildarts: 1 WIP database: client cert storage 3 files changed, 171 insertions(+), 0 deletions(-)
Copy & paste the following snippet into your terminal to import this patchset into git:
curl -s https://lists.sr.ht/~emersion/soju-dev/patches/33843/mbox | git am -3Learn more about email & git
--- Mostly looking for comments on if this is the right approach for storing the client certs in the database. Still need to work on the actual authentication pieces. database/database.go | 10 +++++ database/postgres.go | 72 +++++++++++++++++++++++++++++++++++ database/sqlite.go | 89 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 171 insertions(+) diff --git a/database/database.go b/database/database.go index eb5240e..6618fe0 100644 --- a/database/database.go +++ b/database/database.go @@ -20,6 +20,10 @@ type Database interface { StoreUser(ctx context.Context, user *User) error DeleteUser(ctx context.Context, id int64) error + ListClientCerts(ctx context.Context, userID int64) ([]ClientCert, error) + StoreClientCert(ctx context.Context, cert *ClientCert, userID int64) error + DeleteClientCert(ctx context.Context, id int64) error + ListNetworks(ctx context.Context, userID int64) ([]Network, error) StoreNetwork(ctx context.Context, userID int64, network *Network) error DeleteNetwork(ctx context.Context, id int64) error @@ -103,6 +107,12 @@ func (u *User) SetPassword(password string) error { return nil } +type ClientCert struct { + ID int64 + UserID int64 + Fingerprint string +} + type SASL struct { Mechanism string diff --git a/database/postgres.go b/database/postgres.go index 1015b52..64641fd 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -34,6 +34,13 @@ CREATE TABLE "User" ( realname VARCHAR(255) ); +CREATE TABLE "ClientCert" ( + id SERIAL PRIMARY KEY, + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + fingerprint VARCHAR(128), + UNIQUE("user", fingerprint) +); + CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); CREATE TABLE "Network" ( @@ -155,6 +162,14 @@ var postgresMigrations = []string{ REFERENCES "User"(id) ON DELETE CASCADE `, `ALTER TABLE "User" ADD COLUMN nick VARCHAR(255)`, + ` + CREATE TABLE "ClientCert" ( + id SERIAL PRIMARY KEY, + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + fingerprint VARCHAR(128), + UNIQUE("user", fingerprint) + ) + `, } type PostgresDB struct { @@ -361,6 +376,63 @@ func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error { return err } +func (db *PostgresDB) ListClientCerts(ctx context.Context, userID int64) ([]ClientCert, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + `SELECT id, "user", fingerprint FROM "ClientCert" WHERE "user" = $1`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var certs []ClientCert + for rows.Next() { + var cert ClientCert + if err := rows.Scan(&cert.ID, &cert.UserID, &cert.Fingerprint); err != nil { + return nil, err + } + + certs = append(certs, cert) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return certs, nil +} + +func (db *PostgresDB) StoreClientCert(ctx context.Context, cert *ClientCert, userID int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + var err error + if cert.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "ClientCert" ("user", fingerprint) + VALUSE ($1, $2) + RETURNING id`, + cert.UserID, cert.Fingerprint).Scan(&cert.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "ClientCert" + SET "user" = $1, fingerprint = $2 + WHERE id = $3`, + userID, cert.Fingerprint, cert.ID) + } + + return err +} + +func (db *PostgresDB) DeleteClientCert(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "ClientCert" WHERE id = $1`, id) + return err +} + func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() diff --git a/database/sqlite.go b/database/sqlite.go index fc70f11..c779e9a 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -31,6 +31,14 @@ CREATE TABLE User ( nick TEXT ); +CREATE TABLE ClientCert ( + id INTEGER PRIMARY KEY, + user INTEGER NOT NULL, + fingerprint TEXT, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, fingerprint) +); + CREATE TABLE Network ( id INTEGER PRIMARY KEY, name TEXT, @@ -245,6 +253,15 @@ var sqliteMigrations = []string{ UPDATE WebPushSubscription AS wps SET user = (SELECT n.user FROM Network AS n WHERE n.id = wps.network); `, "ALTER TABLE User ADD COLUMN nick TEXT;", + ` + CREATE TABLE ClientCert ( + id INTEGER PRIMARY KEY, + user INTEGER NOT NULL, + fingerprint TEXT, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, fingerprint) + ); + `, } type SqliteDB struct { @@ -477,6 +494,11 @@ func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error { return err } + _, err = tx.ExecContext(ctx, "DELETE FROM ClientCert WHERE user = ?", id) + if err != nil { + return err + } + _, err = tx.ExecContext(ctx, "DELETE FROM User WHERE id = ?", id) if err != nil { return err @@ -485,6 +507,73 @@ func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error { return tx.Commit() } +func (db *SqliteDB) ListClientCerts(ctx context.Context, userID int64) ([]ClientCert, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + `SELECT id, user, fingerprint FROM ClientCert WHERE user = ?`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var certs []ClientCert + for rows.Next() { + var cert ClientCert + if err := rows.Scan(&cert.ID, &cert.UserID, &cert.Fingerprint); err != nil { + return nil, err + } + + certs = append(certs, cert) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return certs, nil +} + +func (db *SqliteDB) StoreClientCert(ctx context.Context, cert *ClientCert, userID int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("user", cert.UserID), + sql.Named("fingerprint", cert.Fingerprint), + } + + var err error + if cert.ID == 0 { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO ClientCert(user, fingerprint) + VALUES (:user, :fingerprint)`, + args...) + if err != nil { + return err + } + cert.ID, err = res.LastInsertId() + } else { + args = append(args, sql.Named("id", cert.ID)) + _, err = db.db.ExecContext(ctx, ` + UPDATE ClientCert + SET user = :user, fingerprint = :fingerprint + WHERE id = :id`, + args...) + } + + return err +} + +func (db *SqliteDB) DeleteClientCert(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM ClientCert WHERE id = ?`, id) + return err +} + func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() -- 2.36.1.windows.1