delthas: 2 Make max-user-networks a limit of enabled networks Add per-user max-networks limit 12 files changed, 141 insertions(+), 37 deletions(-)
Copy & paste the following snippet into your terminal to import this patchset into git:
curl -s https://lists.sr.ht/~emersion/soju-dev/patches/53205/mbox | git am -3Learn more about email & git
--- doc/soju.1.scd | 3 ++- downstream.go | 19 +++++++++++++++++++ service.go | 13 +++++++++++++ user.go | 29 +++++++++++++++++++++++++---- 4 files changed, 59 insertions(+), 5 deletions(-) diff --git a/doc/soju.1.scd b/doc/soju.1.scd index f06d5f8..9ee7824 100644 --- a/doc/soju.1.scd +++ b/doc/soju.1.scd @@ -189,7 +189,8 @@ The following directives are supported: By default, all IPs are rejected. *max-user-networks* <limit> - Maximum number of networks per user. By default, there is no limit. + Maximum number of enabled networks per user. By default, there is no limit. + The limit is ignored for admin users or when updating networks as an admin. *motd* <path> Path to the MOTD file. The bouncer MOTD is sent to clients which aren't diff --git a/downstream.go b/downstream.go index f48653e..88eddcb 100644 --- a/downstream.go +++ b/downstream.go @@ -3041,6 +3041,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. record.Realname = "" } + if record.Enabled { + if err := dc.user.canEnableNewNetwork(ctx); err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)}, + }} + } + } + network, err := dc.user.createNetwork(ctx, record) if err != nil { return ircError{&irc.Message{ @@ -3073,6 +3082,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } record := net.Network // copy network record because we'll mutate it + wasEnabled := record.Enabled if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil { return err } @@ -3084,6 +3094,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. record.Realname = "" } + if !wasEnabled && record.Enabled { + if err := dc.user.canEnableNewNetwork(ctx); err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)}, + }} + } + } + _, err = dc.user.updateNetwork(ctx, &record) if err != nil { return ircError{&irc.Message{ diff --git a/service.go b/service.go index cda7eec..df30bb8 100644 --- a/service.go +++ b/service.go @@ -583,6 +583,12 @@ func handleServiceNetworkCreate(ctx *serviceContext, params []string) error { return err } + if !ctx.admin && record.Enabled { + if err := ctx.user.canEnableNewNetwork(ctx); err != nil { + return fmt.Errorf("could not create network: %v", err) + } + } + network, err := ctx.user.createNetwork(ctx, record) if err != nil { return fmt.Errorf("could not create network: %v", err) @@ -657,10 +663,17 @@ func handleServiceNetworkUpdate(ctx *serviceContext, params []string) error { } record := net.Network // copy network record because we'll mutate it + wasEnabled := record.Enabled if err := fs.update(&record); err != nil { return err } + if !ctx.admin && !wasEnabled && record.Enabled { + if err := ctx.user.canEnableNewNetwork(ctx); err != nil { + return fmt.Errorf("could not update network: %v", err) + } + } + network, err := ctx.user.updateNetwork(ctx, &record) if err != nil { return fmt.Errorf("could not update network: %v", err) diff --git a/user.go b/user.go index c3a3931..cbe6f5c 100644 --- a/user.go +++ b/user.go @@ -962,6 +962,31 @@ func (u *user) removeNetwork(network *network) { panic("tried to remove a non-existing network") } +func (u *user) canEnableNewNetwork(ctx context.Context) error { + if u.Admin { + return nil + } + max := u.srv.Config().MaxUserNetworks + if max < 0 { + return nil + } + networks, err := u.srv.db.ListNetworks(ctx, u.ID) + if err != nil { + u.logger.Printf("failed to list networks for user %q: %v", u.Username, err) + return err + } + n := 0 + for _, network := range networks { + if network.Enabled { + n++ + } + } + if n >= max { + return fmt.Errorf("maximum number of enabled networks reached") + } + return nil +} + func (u *user) checkNetwork(record *database.Network) error { url, err := record.URL() if err != nil { @@ -1021,10 +1046,6 @@ func (u *user) createNetwork(ctx context.Context, record *database.Network) (*ne return nil, err } - if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max { - return nil, fmt.Errorf("maximum number of networks reached") - } - network := newNetwork(u, record, nil) err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network) if err != nil { base-commit: 75a58cc2cb078a17cbb14249af85c8683d06deee -- 2.38.0
This is formatted in user status as, e.g.: delthas (admin, disabled, max 3 networks): 2 networks --- database/database.go | 1 + database/postgres.go | 21 +++++++++--------- database/postgres_schema.sql | 3 ++- database/sqlite.go | 17 ++++++++------ database/sqlite_schema.sql | 3 ++- doc/soju.1.scd | 9 ++++++-- service.go | 43 +++++++++++++++++++++++++++++++----- user.go | 17 +++++++++----- 8 files changed, 82 insertions(+), 32 deletions(-) diff --git a/database/database.go b/database/database.go index 10e7e38..9b1eb65 100644 --- a/database/database.go +++ b/database/database.go @@ -96,6 +96,7 @@ type User struct { Admin bool Enabled bool DownstreamInteractedAt time.Time + MaxNetworks int } func NewUser(username string) *User { diff --git a/database/postgres.go b/database/postgres.go index c401119..f94d765 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -113,6 +113,7 @@ var postgresMigrations = []string{ CREATE INDEX "MessageIndex" ON "Message" (target, time); CREATE INDEX "MessageSearchIndex" ON "Message" USING GIN (text_search); `, + `ALTER TABLE "User" ADD COLUMN max_networks INTEGER NOT NULL DEFAULT -1`, } type PostgresDB struct { @@ -257,7 +258,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { rows, err := db.db.QueryContext(ctx, `SELECT id, username, password, admin, nick, realname, enabled, - downstream_interacted_at + downstream_interacted_at, max_networks FROM "User"`) if err != nil { return nil, err @@ -269,7 +270,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { var user User var password, nick, realname sql.NullString var downstreamInteractedAt sql.NullTime - if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil { + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil { return nil, err } user.Password = password.String @@ -294,11 +295,11 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro var password, nick, realname sql.NullString var downstreamInteractedAt sql.NullTime row := db.db.QueryRowContext(ctx, - `SELECT id, password, admin, nick, realname, enabled, downstream_interacted_at + `SELECT id, password, admin, nick, realname, enabled, downstream_interacted_at, max_networks FROM "User" WHERE username = $1`, username) - if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil { + if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil { return nil, err } user.Password = password.String @@ -348,19 +349,19 @@ func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { if user.ID == 0 { err = db.db.QueryRowContext(ctx, ` INSERT INTO "User" (username, password, admin, nick, realname, - enabled, downstream_interacted_at) - VALUES ($1, $2, $3, $4, $5, $6, $7) + enabled, downstream_interacted_at, max_networks) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id`, user.Username, password, user.Admin, nick, realname, user.Enabled, - downstreamInteractedAt).Scan(&user.ID) + downstreamInteractedAt, user.MaxNetworks).Scan(&user.ID) } else { _, err = db.db.ExecContext(ctx, ` UPDATE "User" SET password = $1, admin = $2, nick = $3, realname = $4, - enabled = $5, downstream_interacted_at = $6 - WHERE id = $7`, + enabled = $5, downstream_interacted_at = $6, max_networks = $7 + WHERE id = $8`, password, user.Admin, nick, realname, user.Enabled, - downstreamInteractedAt, user.ID) + downstreamInteractedAt, user.MaxNetworks, user.ID) } return err } diff --git a/database/postgres_schema.sql b/database/postgres_schema.sql index 9643bb7..ec21ef4 100644 --- a/database/postgres_schema.sql +++ b/database/postgres_schema.sql @@ -7,7 +7,8 @@ CREATE TABLE "User" ( realname VARCHAR(255), created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), enabled BOOLEAN NOT NULL DEFAULT TRUE, - downstream_interacted_at TIMESTAMP WITH TIME ZONE + downstream_interacted_at TIMESTAMP WITH TIME ZONE, + max_networks INTEGER NOT NULL DEFAULT -1 ); CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); diff --git a/database/sqlite.go b/database/sqlite.go index f251f03..d8f99ba 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -242,6 +242,7 @@ var sqliteMigrations = []string{ INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text); END; `, + "ALTER TABLE User ADD COLUMN max_networks INTEGER NOT NULL DEFAULT -1", } type SqliteDB struct { @@ -342,7 +343,7 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { rows, err := db.db.QueryContext(ctx, `SELECT id, username, password, admin, nick, realname, enabled, - downstream_interacted_at + downstream_interacted_at, max_networks FROM User`) if err != nil { return nil, err @@ -354,7 +355,7 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { var user User var password, nick, realname sql.NullString var downstreamInteractedAt sqliteTime - if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil { + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil { return nil, err } user.Password = password.String @@ -380,11 +381,11 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) var downstreamInteractedAt sqliteTime row := db.db.QueryRowContext(ctx, `SELECT id, password, admin, nick, realname, enabled, - downstream_interacted_at + downstream_interacted_at, max_networks FROM User WHERE username = ?`, username) - if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil { + if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil { return nil, err } user.Password = password.String @@ -434,6 +435,7 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { sql.Named("enabled", user.Enabled), sql.Named("now", sqliteTime{time.Now()}), sql.Named("downstream_interacted_at", sqliteTime{user.DownstreamInteractedAt}), + sql.Named("max_networks", user.MaxNetworks), } var err error @@ -442,7 +444,8 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { UPDATE User SET password = :password, admin = :admin, nick = :nick, realname = :realname, enabled = :enabled, - downstream_interacted_at = :downstream_interacted_at + downstream_interacted_at = :downstream_interacted_at, + max_networks = :max_networks WHERE username = :username`, args...) } else { @@ -450,9 +453,9 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { res, err = db.db.ExecContext(ctx, ` INSERT INTO User(username, password, admin, nick, realname, created_at, - enabled, downstream_interacted_at) + enabled, downstream_interacted_at, max_networks) VALUES (:username, :password, :admin, :nick, :realname, :now, - :enabled, :downstream_interacted_at)`, + :enabled, :downstream_interacted_at, :max_networks)`, args...) if err != nil { return err diff --git a/database/sqlite_schema.sql b/database/sqlite_schema.sql index b8445d1..fc10578 100644 --- a/database/sqlite_schema.sql +++ b/database/sqlite_schema.sql @@ -7,7 +7,8 @@ CREATE TABLE User ( nick TEXT, created_at TEXT NOT NULL, enabled INTEGER NOT NULL DEFAULT 1, - downstream_interacted_at TEXT + downstream_interacted_at TEXT, + max_networks INTEGER NOT NULL DEFAULT -1 ); CREATE TABLE Network ( diff --git a/doc/soju.1.scd b/doc/soju.1.scd index 9ee7824..5bc9f32 100644 --- a/doc/soju.1.scd +++ b/doc/soju.1.scd @@ -519,6 +519,11 @@ character. not connect to any of their networks, and downstream connections will be immediately closed. By default, users are enabled. + *-max-networks* <max-networks> + Set a limit on the number of enabled networks this user can use. A limit + of 0 means no network, and -1 means to default to the global + _max-user-networks_ configuration value. + *user update* [username] [options...] Update a user. The options are the same as the _user create_ command. @@ -530,8 +535,8 @@ character. - The _-username_ flag is never valid, usernames are immutable. - The _-nick_ and _-realname_ flag are only valid when updating the current user. - - The _-admin_ and _-enabled_ flags are only valid when updating another - user. + - The _-admin_, _-enabled_ and _-max_networks_ flags are only valid when + updating another user. *user delete* <username> [confirmation token] Delete a soju user. diff --git a/service.go b/service.go index df30bb8..3b50055 100644 --- a/service.go +++ b/service.go @@ -286,14 +286,14 @@ func init() { global: true, }, "create": { - usage: "-username <username> -password <password> [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false]", + usage: "-username <username> -password <password> [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false] [-max-networks <max-networks>]", desc: "create a new soju user", handle: handleUserCreate, admin: true, global: true, }, "update": { - usage: "[username] [-password <password>] [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false]", + usage: "[username] [-password <password>] [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false] [-max-networks <max-networks>]", desc: "update a user", handle: handleUserUpdate, global: true, @@ -458,6 +458,26 @@ func (f boolPtrFlag) Set(s string) error { return nil } +type intPtrFlag struct { + ptr **int +} + +func (f intPtrFlag) String() string { + if f.ptr == nil || *f.ptr == nil { + return "<nil>" + } + return strconv.Itoa(**f.ptr) +} + +func (f intPtrFlag) Set(s string) error { + v, err := strconv.Atoi(s) + if err != nil { + return err + } + *f.ptr = &v + return nil +} + func getNetworkFromArg(ctx *serviceContext, params []string) (*network, []string, error) { name, params := popArg(params) if name == "" { @@ -951,6 +971,9 @@ func handleUserStatus(ctx *serviceContext, params []string) error { if !user.Enabled { attrs = append(attrs, "disabled") } + if user.MaxNetworks >= 0 { + attrs = append(attrs, fmt.Sprintf("max %d networks", user.MaxNetworks)) + } line := user.Username if len(attrs) > 0 { @@ -979,6 +1002,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error { realname := fs.String("realname", "", "") admin := fs.Bool("admin", false, "") enabled := fs.Bool("enabled", true, "") + maxNetworks := fs.Int("max-networks", -1, "") if err := fs.Parse(params); err != nil { return err @@ -1001,6 +1025,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error { user.Realname = *realname user.Admin = *admin user.Enabled = *enabled + user.MaxNetworks = *maxNetworks if !*disablePassword { if err := user.SetPassword(*password); err != nil { return err @@ -1025,6 +1050,7 @@ func handleUserUpdate(ctx *serviceContext, params []string) error { var password, nick, realname *string var admin, enabled *bool var disablePassword bool + var maxNetworks *int fs := newFlagSet() fs.Var(stringPtrFlag{&password}, "password", "") fs.BoolVar(&disablePassword, "disable-password", false, "") @@ -1032,6 +1058,7 @@ func handleUserUpdate(ctx *serviceContext, params []string) error { fs.Var(stringPtrFlag{&realname}, "realname", "") fs.Var(boolPtrFlag{&admin}, "admin", "") fs.Var(boolPtrFlag{&enabled}, "enabled", "") + fs.Var(intPtrFlag{&maxNetworks}, "max-networks", "") username, params := popArg(params) if err := fs.Parse(params); err != nil { @@ -1079,10 +1106,11 @@ func handleUserUpdate(ctx *serviceContext, params []string) error { done := make(chan error, 1) event := eventUserUpdate{ - password: hashed, - admin: admin, - enabled: enabled, - done: done, + password: hashed, + admin: admin, + enabled: enabled, + maxNetworks: maxNetworks, + done: done, } select { case <-ctx.Done(): @@ -1102,6 +1130,9 @@ func handleUserUpdate(ctx *serviceContext, params []string) error { if enabled != nil { return fmt.Errorf("cannot update -enabled of own user") } + if maxNetworks != nil { + return fmt.Errorf("cannot update -max-networks of own user") + } err := ctx.user.updateUser(ctx, func(record *database.User) error { if password != nil { diff --git a/user.go b/user.go index cbe6f5c..0ad6dba 100644 --- a/user.go +++ b/user.go @@ -75,10 +75,11 @@ type eventBroadcast struct { type eventStop struct{} type eventUserUpdate struct { - password *string - admin *bool - enabled *bool - done chan error + password *string + admin *bool + enabled *bool + maxNetworks *int + done chan error } type eventTryRegainNick struct { @@ -821,6 +822,9 @@ func (u *user) run() { if e.enabled != nil { record.Enabled = *e.enabled } + if e.maxNetworks != nil { + record.MaxNetworks = *e.maxNetworks + } return nil }) @@ -966,7 +970,10 @@ func (u *user) canEnableNewNetwork(ctx context.Context) error { if u.Admin { return nil } - max := u.srv.Config().MaxUserNetworks + max := u.MaxNetworks + if max < 0 { + max = u.srv.Config().MaxUserNetworks + } if max < 0 { return nil } -- 2.38.0