~emersion/soju-dev

Make max-user-networks a limit of enabled networks v1 PROPOSED

delthas: 2
 Make max-user-networks a limit of enabled networks
 Add per-user max-networks limit

 12 files changed, 141 insertions(+), 37 deletions(-)
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~emersion/soju-dev/patches/53205/mbox | git am -3
Learn more about email & git

[PATCH 1/2] Make max-user-networks a limit of enabled networks Export this patch

---
 doc/soju.1.scd |  3 ++-
 downstream.go  | 19 +++++++++++++++++++
 service.go     | 13 +++++++++++++
 user.go        | 29 +++++++++++++++++++++++++----
 4 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/doc/soju.1.scd b/doc/soju.1.scd
index f06d5f8..9ee7824 100644
--- a/doc/soju.1.scd
+++ b/doc/soju.1.scd
@@ -189,7 +189,8 @@ The following directives are supported:
	By default, all IPs are rejected.

*max-user-networks* <limit>
	Maximum number of networks per user. By default, there is no limit.
	Maximum number of enabled networks per user. By default, there is no limit.
	The limit is ignored for admin users or when updating networks as an admin.

*motd* <path>
	Path to the MOTD file. The bouncer MOTD is sent to clients which aren't
diff --git a/downstream.go b/downstream.go
index f48653e..88eddcb 100644
--- a/downstream.go
+++ b/downstream.go
@@ -3041,6 +3041,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
				record.Realname = ""
			}

			if record.Enabled {
				if err := dc.user.canEnableNewNetwork(ctx); err != nil {
					return ircError{&irc.Message{
						Command: "FAIL",
						Params:  []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)},
					}}
				}
			}

			network, err := dc.user.createNetwork(ctx, record)
			if err != nil {
				return ircError{&irc.Message{
@@ -3073,6 +3082,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
			}

			record := net.Network // copy network record because we'll mutate it
			wasEnabled := record.Enabled
			if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil {
				return err
			}
@@ -3084,6 +3094,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
				record.Realname = ""
			}

			if !wasEnabled && record.Enabled {
				if err := dc.user.canEnableNewNetwork(ctx); err != nil {
					return ircError{&irc.Message{
						Command: "FAIL",
						Params:  []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)},
					}}
				}
			}

			_, err = dc.user.updateNetwork(ctx, &record)
			if err != nil {
				return ircError{&irc.Message{
diff --git a/service.go b/service.go
index cda7eec..df30bb8 100644
--- a/service.go
+++ b/service.go
@@ -583,6 +583,12 @@ func handleServiceNetworkCreate(ctx *serviceContext, params []string) error {
		return err
	}

	if !ctx.admin && record.Enabled {
		if err := ctx.user.canEnableNewNetwork(ctx); err != nil {
			return fmt.Errorf("could not create network: %v", err)
		}
	}

	network, err := ctx.user.createNetwork(ctx, record)
	if err != nil {
		return fmt.Errorf("could not create network: %v", err)
@@ -657,10 +663,17 @@ func handleServiceNetworkUpdate(ctx *serviceContext, params []string) error {
	}

	record := net.Network // copy network record because we'll mutate it
	wasEnabled := record.Enabled
	if err := fs.update(&record); err != nil {
		return err
	}

	if !ctx.admin && !wasEnabled && record.Enabled {
		if err := ctx.user.canEnableNewNetwork(ctx); err != nil {
			return fmt.Errorf("could not update network: %v", err)
		}
	}

	network, err := ctx.user.updateNetwork(ctx, &record)
	if err != nil {
		return fmt.Errorf("could not update network: %v", err)
diff --git a/user.go b/user.go
index c3a3931..cbe6f5c 100644
--- a/user.go
+++ b/user.go
@@ -962,6 +962,31 @@ func (u *user) removeNetwork(network *network) {
	panic("tried to remove a non-existing network")
}

func (u *user) canEnableNewNetwork(ctx context.Context) error {
	if u.Admin {
		return nil
	}
	max := u.srv.Config().MaxUserNetworks
	if max < 0 {
		return nil
	}
	networks, err := u.srv.db.ListNetworks(ctx, u.ID)
	if err != nil {
		u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
		return err
	}
	n := 0
	for _, network := range networks {
		if network.Enabled {
			n++
		}
	}
	if n >= max {
		return fmt.Errorf("maximum number of enabled networks reached")
	}
	return nil
}

func (u *user) checkNetwork(record *database.Network) error {
	url, err := record.URL()
	if err != nil {
@@ -1021,10 +1046,6 @@ func (u *user) createNetwork(ctx context.Context, record *database.Network) (*ne
		return nil, err
	}

	if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
		return nil, fmt.Errorf("maximum number of networks reached")
	}

	network := newNetwork(u, record, nil)
	err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
	if err != nil {

base-commit: 75a58cc2cb078a17cbb14249af85c8683d06deee
-- 
2.38.0

[PATCH 2/2] Add per-user max-networks limit Export this patch

This is formatted in user status as, e.g.:

  delthas (admin, disabled, max 3 networks): 2 networks
---
 database/database.go         |  1 +
 database/postgres.go         | 21 +++++++++---------
 database/postgres_schema.sql |  3 ++-
 database/sqlite.go           | 17 ++++++++------
 database/sqlite_schema.sql   |  3 ++-
 doc/soju.1.scd               |  9 ++++++--
 service.go                   | 43 +++++++++++++++++++++++++++++++-----
 user.go                      | 17 +++++++++-----
 8 files changed, 82 insertions(+), 32 deletions(-)

diff --git a/database/database.go b/database/database.go
index 10e7e38..9b1eb65 100644
--- a/database/database.go
+++ b/database/database.go
@@ -96,6 +96,7 @@ type User struct {
	Admin                  bool
	Enabled                bool
	DownstreamInteractedAt time.Time
	MaxNetworks            int
}

func NewUser(username string) *User {
diff --git a/database/postgres.go b/database/postgres.go
index c401119..f94d765 100644
--- a/database/postgres.go
+++ b/database/postgres.go
@@ -113,6 +113,7 @@ var postgresMigrations = []string{
		CREATE INDEX "MessageIndex" ON "Message" (target, time);
		CREATE INDEX "MessageSearchIndex" ON "Message" USING GIN (text_search);
	`,
	`ALTER TABLE "User" ADD COLUMN max_networks INTEGER NOT NULL DEFAULT -1`,
}

type PostgresDB struct {
@@ -257,7 +258,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {

	rows, err := db.db.QueryContext(ctx,
		`SELECT id, username, password, admin, nick, realname, enabled,
			downstream_interacted_at
			downstream_interacted_at, max_networks
		FROM "User"`)
	if err != nil {
		return nil, err
@@ -269,7 +270,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
		var user User
		var password, nick, realname sql.NullString
		var downstreamInteractedAt sql.NullTime
		if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
		if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil {
			return nil, err
		}
		user.Password = password.String
@@ -294,11 +295,11 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro
	var password, nick, realname sql.NullString
	var downstreamInteractedAt sql.NullTime
	row := db.db.QueryRowContext(ctx,
		`SELECT id, password, admin, nick, realname, enabled, downstream_interacted_at
		`SELECT id, password, admin, nick, realname, enabled, downstream_interacted_at, max_networks
		FROM "User"
		WHERE username = $1`,
		username)
	if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
	if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil {
		return nil, err
	}
	user.Password = password.String
@@ -348,19 +349,19 @@ func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
	if user.ID == 0 {
		err = db.db.QueryRowContext(ctx, `
			INSERT INTO "User" (username, password, admin, nick, realname,
				enabled, downstream_interacted_at)
			VALUES ($1, $2, $3, $4, $5, $6, $7)
				enabled, downstream_interacted_at, max_networks)
			VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
			RETURNING id`,
			user.Username, password, user.Admin, nick, realname, user.Enabled,
			downstreamInteractedAt).Scan(&user.ID)
			downstreamInteractedAt, user.MaxNetworks).Scan(&user.ID)
	} else {
		_, err = db.db.ExecContext(ctx, `
			UPDATE "User"
			SET password = $1, admin = $2, nick = $3, realname = $4,
				enabled = $5, downstream_interacted_at = $6
			WHERE id = $7`,
				enabled = $5, downstream_interacted_at = $6, max_networks = $7
			WHERE id = $8`,
			password, user.Admin, nick, realname, user.Enabled,
			downstreamInteractedAt, user.ID)
			downstreamInteractedAt, user.MaxNetworks, user.ID)
	}
	return err
}
diff --git a/database/postgres_schema.sql b/database/postgres_schema.sql
index 9643bb7..ec21ef4 100644
--- a/database/postgres_schema.sql
+++ b/database/postgres_schema.sql
@@ -7,7 +7,8 @@ CREATE TABLE "User" (
	realname VARCHAR(255),
	created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	downstream_interacted_at TIMESTAMP WITH TIME ZONE
	downstream_interacted_at TIMESTAMP WITH TIME ZONE,
	max_networks INTEGER NOT NULL DEFAULT -1
);

CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
diff --git a/database/sqlite.go b/database/sqlite.go
index f251f03..d8f99ba 100644
--- a/database/sqlite.go
+++ b/database/sqlite.go
@@ -242,6 +242,7 @@ var sqliteMigrations = []string{
			INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text);
		END;
	`,
	"ALTER TABLE User ADD COLUMN max_networks INTEGER NOT NULL DEFAULT -1",
}

type SqliteDB struct {
@@ -342,7 +343,7 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {

	rows, err := db.db.QueryContext(ctx,
		`SELECT id, username, password, admin, nick, realname, enabled,
			downstream_interacted_at
			downstream_interacted_at, max_networks
		FROM User`)
	if err != nil {
		return nil, err
@@ -354,7 +355,7 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
		var user User
		var password, nick, realname sql.NullString
		var downstreamInteractedAt sqliteTime
		if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
		if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil {
			return nil, err
		}
		user.Password = password.String
@@ -380,11 +381,11 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
	var downstreamInteractedAt sqliteTime
	row := db.db.QueryRowContext(ctx,
		`SELECT id, password, admin, nick, realname, enabled,
			downstream_interacted_at
			downstream_interacted_at, max_networks
		FROM User
		WHERE username = ?`,
		username)
	if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
	if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil {
		return nil, err
	}
	user.Password = password.String
@@ -434,6 +435,7 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
		sql.Named("enabled", user.Enabled),
		sql.Named("now", sqliteTime{time.Now()}),
		sql.Named("downstream_interacted_at", sqliteTime{user.DownstreamInteractedAt}),
		sql.Named("max_networks", user.MaxNetworks),
	}

	var err error
@@ -442,7 +444,8 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
			UPDATE User
			SET password = :password, admin = :admin, nick = :nick,
				realname = :realname, enabled = :enabled,
				downstream_interacted_at = :downstream_interacted_at
				downstream_interacted_at = :downstream_interacted_at,
				max_networks = :max_networks
			WHERE username = :username`,
			args...)
	} else {
@@ -450,9 +453,9 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
		res, err = db.db.ExecContext(ctx, `
			INSERT INTO
			User(username, password, admin, nick, realname, created_at,
				enabled, downstream_interacted_at)
				enabled, downstream_interacted_at, max_networks)
			VALUES (:username, :password, :admin, :nick, :realname, :now,
				:enabled, :downstream_interacted_at)`,
				:enabled, :downstream_interacted_at, :max_networks)`,
			args...)
		if err != nil {
			return err
diff --git a/database/sqlite_schema.sql b/database/sqlite_schema.sql
index b8445d1..fc10578 100644
--- a/database/sqlite_schema.sql
+++ b/database/sqlite_schema.sql
@@ -7,7 +7,8 @@ CREATE TABLE User (
	nick TEXT,
	created_at TEXT NOT NULL,
	enabled INTEGER NOT NULL DEFAULT 1,
	downstream_interacted_at TEXT
	downstream_interacted_at TEXT,
	max_networks INTEGER NOT NULL DEFAULT -1
);

CREATE TABLE Network (
diff --git a/doc/soju.1.scd b/doc/soju.1.scd
index 9ee7824..5bc9f32 100644
--- a/doc/soju.1.scd
+++ b/doc/soju.1.scd
@@ -519,6 +519,11 @@ character.
		not connect to any of their networks, and downstream connections will
		be immediately closed. By default, users are enabled.

	*-max-networks* <max-networks>
		Set a limit on the number of enabled networks this user can use. A limit
		of 0 means no network, and -1 means to default to the global
		_max-user-networks_ configuration value.

*user update* [username] [options...]
	Update a user. The options are the same as the _user create_ command.

@@ -530,8 +535,8 @@ character.
	- The _-username_ flag is never valid, usernames are immutable.
	- The _-nick_ and _-realname_ flag are only valid when updating the current
	  user.
	- The _-admin_ and _-enabled_ flags are only valid when updating another
	  user.
	- The _-admin_, _-enabled_ and _-max_networks_ flags are only valid when
	  updating another user.

*user delete* <username> [confirmation token]
	Delete a soju user.
diff --git a/service.go b/service.go
index df30bb8..3b50055 100644
--- a/service.go
+++ b/service.go
@@ -286,14 +286,14 @@ func init() {
					global: true,
				},
				"create": {
					usage:  "-username <username> -password <password> [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false]",
					usage:  "-username <username> -password <password> [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false]  [-max-networks <max-networks>]",
					desc:   "create a new soju user",
					handle: handleUserCreate,
					admin:  true,
					global: true,
				},
				"update": {
					usage:  "[username] [-password <password>] [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false]",
					usage:  "[username] [-password <password>] [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false] [-max-networks <max-networks>]",
					desc:   "update a user",
					handle: handleUserUpdate,
					global: true,
@@ -458,6 +458,26 @@ func (f boolPtrFlag) Set(s string) error {
	return nil
}

type intPtrFlag struct {
	ptr **int
}

func (f intPtrFlag) String() string {
	if f.ptr == nil || *f.ptr == nil {
		return "<nil>"
	}
	return strconv.Itoa(**f.ptr)
}

func (f intPtrFlag) Set(s string) error {
	v, err := strconv.Atoi(s)
	if err != nil {
		return err
	}
	*f.ptr = &v
	return nil
}

func getNetworkFromArg(ctx *serviceContext, params []string) (*network, []string, error) {
	name, params := popArg(params)
	if name == "" {
@@ -951,6 +971,9 @@ func handleUserStatus(ctx *serviceContext, params []string) error {
		if !user.Enabled {
			attrs = append(attrs, "disabled")
		}
		if user.MaxNetworks >= 0 {
			attrs = append(attrs, fmt.Sprintf("max %d networks", user.MaxNetworks))
		}

		line := user.Username
		if len(attrs) > 0 {
@@ -979,6 +1002,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error {
	realname := fs.String("realname", "", "")
	admin := fs.Bool("admin", false, "")
	enabled := fs.Bool("enabled", true, "")
	maxNetworks := fs.Int("max-networks", -1, "")

	if err := fs.Parse(params); err != nil {
		return err
@@ -1001,6 +1025,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error {
	user.Realname = *realname
	user.Admin = *admin
	user.Enabled = *enabled
	user.MaxNetworks = *maxNetworks
	if !*disablePassword {
		if err := user.SetPassword(*password); err != nil {
			return err
@@ -1025,6 +1050,7 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
	var password, nick, realname *string
	var admin, enabled *bool
	var disablePassword bool
	var maxNetworks *int
	fs := newFlagSet()
	fs.Var(stringPtrFlag{&password}, "password", "")
	fs.BoolVar(&disablePassword, "disable-password", false, "")
@@ -1032,6 +1058,7 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
	fs.Var(stringPtrFlag{&realname}, "realname", "")
	fs.Var(boolPtrFlag{&admin}, "admin", "")
	fs.Var(boolPtrFlag{&enabled}, "enabled", "")
	fs.Var(intPtrFlag{&maxNetworks}, "max-networks", "")

	username, params := popArg(params)
	if err := fs.Parse(params); err != nil {
@@ -1079,10 +1106,11 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {

		done := make(chan error, 1)
		event := eventUserUpdate{
			password: hashed,
			admin:    admin,
			enabled:  enabled,
			done:     done,
			password:    hashed,
			admin:       admin,
			enabled:     enabled,
			maxNetworks: maxNetworks,
			done:        done,
		}
		select {
		case <-ctx.Done():
@@ -1102,6 +1130,9 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
		if enabled != nil {
			return fmt.Errorf("cannot update -enabled of own user")
		}
		if maxNetworks != nil {
			return fmt.Errorf("cannot update -max-networks of own user")
		}

		err := ctx.user.updateUser(ctx, func(record *database.User) error {
			if password != nil {
diff --git a/user.go b/user.go
index cbe6f5c..0ad6dba 100644
--- a/user.go
+++ b/user.go
@@ -75,10 +75,11 @@ type eventBroadcast struct {
type eventStop struct{}

type eventUserUpdate struct {
	password *string
	admin    *bool
	enabled  *bool
	done     chan error
	password    *string
	admin       *bool
	enabled     *bool
	maxNetworks *int
	done        chan error
}

type eventTryRegainNick struct {
@@ -821,6 +822,9 @@ func (u *user) run() {
				if e.enabled != nil {
					record.Enabled = *e.enabled
				}
				if e.maxNetworks != nil {
					record.MaxNetworks = *e.maxNetworks
				}
				return nil
			})

@@ -966,7 +970,10 @@ func (u *user) canEnableNewNetwork(ctx context.Context) error {
	if u.Admin {
		return nil
	}
	max := u.srv.Config().MaxUserNetworks
	max := u.MaxNetworks
	if max < 0 {
		max = u.srv.Config().MaxUserNetworks
	}
	if max < 0 {
		return nil
	}
-- 
2.38.0