---
doc/soju.1.scd | 3 ++-
downstream.go | 19 +++++++++++++++++++
service.go | 13 +++++++++++++
user.go | 29 +++++++++++++++++++++++++----
4 files changed, 59 insertions(+), 5 deletions(-)
diff --git a/doc/soju.1.scd b/doc/soju.1.scd
index f06d5f8..9ee7824 100644
--- a/doc/soju.1.scd
+++ b/doc/soju.1.scd
@@ -189,7 +189,8 @@ The following directives are supported:
By default, all IPs are rejected.
*max-user-networks* <limit>
- Maximum number of networks per user. By default, there is no limit.
+ Maximum number of enabled networks per user. By default, there is no limit.
+ The limit is ignored for admin users or when updating networks as an admin.
*motd* <path>
Path to the MOTD file. The bouncer MOTD is sent to clients which aren't
diff --git a/downstream.go b/downstream.go
index f48653e..88eddcb 100644
--- a/downstream.go
+++ b/downstream.go
@@ -3041,6 +3041,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
record.Realname = ""
}
+ if record.Enabled {
+ if err := dc.user.canEnableNewNetwork(ctx); err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)},
+ }}
+ }
+ }
+
network, err := dc.user.createNetwork(ctx, record)
if err != nil {
return ircError{&irc.Message{
@@ -3073,6 +3082,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
record := net.Network // copy network record because we'll mutate it
+ wasEnabled := record.Enabled
if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil {
return err
}
@@ -3084,6 +3094,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
record.Realname = ""
}
+ if !wasEnabled && record.Enabled {
+ if err := dc.user.canEnableNewNetwork(ctx); err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to update network: %v", err)},
+ }}
+ }
+ }
+
_, err = dc.user.updateNetwork(ctx, &record)
if err != nil {
return ircError{&irc.Message{
diff --git a/service.go b/service.go
index cda7eec..df30bb8 100644
--- a/service.go
+++ b/service.go
@@ -583,6 +583,12 @@ func handleServiceNetworkCreate(ctx *serviceContext, params []string) error {
return err
}
+ if !ctx.admin && record.Enabled {
+ if err := ctx.user.canEnableNewNetwork(ctx); err != nil {
+ return fmt.Errorf("could not create network: %v", err)
+ }
+ }
+
network, err := ctx.user.createNetwork(ctx, record)
if err != nil {
return fmt.Errorf("could not create network: %v", err)
@@ -657,10 +663,17 @@ func handleServiceNetworkUpdate(ctx *serviceContext, params []string) error {
}
record := net.Network // copy network record because we'll mutate it
+ wasEnabled := record.Enabled
if err := fs.update(&record); err != nil {
return err
}
+ if !ctx.admin && !wasEnabled && record.Enabled {
+ if err := ctx.user.canEnableNewNetwork(ctx); err != nil {
+ return fmt.Errorf("could not update network: %v", err)
+ }
+ }
+
network, err := ctx.user.updateNetwork(ctx, &record)
if err != nil {
return fmt.Errorf("could not update network: %v", err)
diff --git a/user.go b/user.go
index c3a3931..cbe6f5c 100644
--- a/user.go
+++ b/user.go
@@ -962,6 +962,31 @@ func (u *user) removeNetwork(network *network) {
panic("tried to remove a non-existing network")
}
+func (u *user) canEnableNewNetwork(ctx context.Context) error {
+ if u.Admin {
+ return nil
+ }
+ max := u.srv.Config().MaxUserNetworks
+ if max < 0 {
+ return nil
+ }
+ networks, err := u.srv.db.ListNetworks(ctx, u.ID)
+ if err != nil {
+ u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
+ return err
+ }
+ n := 0
+ for _, network := range networks {
+ if network.Enabled {
+ n++
+ }
+ }
+ if n >= max {
+ return fmt.Errorf("maximum number of enabled networks reached")
+ }
+ return nil
+}
+
func (u *user) checkNetwork(record *database.Network) error {
url, err := record.URL()
if err != nil {
@@ -1021,10 +1046,6 @@ func (u *user) createNetwork(ctx context.Context, record *database.Network) (*ne
return nil, err
}
- if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
- return nil, fmt.Errorf("maximum number of networks reached")
- }
-
network := newNetwork(u, record, nil)
err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
if err != nil {
base-commit: 75a58cc2cb078a17cbb14249af85c8683d06deee
--
2.38.0
This is formatted in user status as, e.g.:
delthas (admin, disabled, max 3 networks): 2 networks
---
database/database.go | 1 +
database/postgres.go | 21 +++++++++---------
database/postgres_schema.sql | 3 ++-
database/sqlite.go | 17 ++++++++------
database/sqlite_schema.sql | 3 ++-
doc/soju.1.scd | 9 ++++++--
service.go | 43 +++++++++++++++++++++++++++++++-----
user.go | 17 +++++++++-----
8 files changed, 82 insertions(+), 32 deletions(-)
diff --git a/database/database.go b/database/database.go
index 10e7e38..9b1eb65 100644
--- a/database/database.go
+++ b/database/database.go
@@ -96,6 +96,7 @@ type User struct {
Admin bool
Enabled bool
DownstreamInteractedAt time.Time
+ MaxNetworks int
}
func NewUser(username string) *User {
diff --git a/database/postgres.go b/database/postgres.go
index c401119..f94d765 100644
--- a/database/postgres.go
+++ b/database/postgres.go
@@ -113,6 +113,7 @@ var postgresMigrations = []string{
CREATE INDEX "MessageIndex" ON "Message" (target, time);
CREATE INDEX "MessageSearchIndex" ON "Message" USING GIN (text_search);
`,
+ `ALTER TABLE "User" ADD COLUMN max_networks INTEGER NOT NULL DEFAULT -1`,
}
type PostgresDB struct {
@@ -257,7 +258,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
rows, err := db.db.QueryContext(ctx,
`SELECT id, username, password, admin, nick, realname, enabled,
- downstream_interacted_at
+ downstream_interacted_at, max_networks
FROM "User"`)
if err != nil {
return nil, err
@@ -269,7 +270,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
var user User
var password, nick, realname sql.NullString
var downstreamInteractedAt sql.NullTime
- if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
+ if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil {
return nil, err
}
user.Password = password.String
@@ -294,11 +295,11 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro
var password, nick, realname sql.NullString
var downstreamInteractedAt sql.NullTime
row := db.db.QueryRowContext(ctx,
- `SELECT id, password, admin, nick, realname, enabled, downstream_interacted_at
+ `SELECT id, password, admin, nick, realname, enabled, downstream_interacted_at, max_networks
FROM "User"
WHERE username = $1`,
username)
- if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
+ if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil {
return nil, err
}
user.Password = password.String
@@ -348,19 +349,19 @@ func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
if user.ID == 0 {
err = db.db.QueryRowContext(ctx, `
INSERT INTO "User" (username, password, admin, nick, realname,
- enabled, downstream_interacted_at)
- VALUES ($1, $2, $3, $4, $5, $6, $7)
+ enabled, downstream_interacted_at, max_networks)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id`,
user.Username, password, user.Admin, nick, realname, user.Enabled,
- downstreamInteractedAt).Scan(&user.ID)
+ downstreamInteractedAt, user.MaxNetworks).Scan(&user.ID)
} else {
_, err = db.db.ExecContext(ctx, `
UPDATE "User"
SET password = $1, admin = $2, nick = $3, realname = $4,
- enabled = $5, downstream_interacted_at = $6
- WHERE id = $7`,
+ enabled = $5, downstream_interacted_at = $6, max_networks = $7
+ WHERE id = $8`,
password, user.Admin, nick, realname, user.Enabled,
- downstreamInteractedAt, user.ID)
+ downstreamInteractedAt, user.MaxNetworks, user.ID)
}
return err
}
diff --git a/database/postgres_schema.sql b/database/postgres_schema.sql
index 9643bb7..ec21ef4 100644
--- a/database/postgres_schema.sql
+++ b/database/postgres_schema.sql
@@ -7,7 +7,8 @@ CREATE TABLE "User" (
realname VARCHAR(255),
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
enabled BOOLEAN NOT NULL DEFAULT TRUE,
- downstream_interacted_at TIMESTAMP WITH TIME ZONE
+ downstream_interacted_at TIMESTAMP WITH TIME ZONE,
+ max_networks INTEGER NOT NULL DEFAULT -1
);
CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
diff --git a/database/sqlite.go b/database/sqlite.go
index f251f03..d8f99ba 100644
--- a/database/sqlite.go
+++ b/database/sqlite.go
@@ -242,6 +242,7 @@ var sqliteMigrations = []string{
INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text);
END;
`,
+ "ALTER TABLE User ADD COLUMN max_networks INTEGER NOT NULL DEFAULT -1",
}
type SqliteDB struct {
@@ -342,7 +343,7 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
rows, err := db.db.QueryContext(ctx,
`SELECT id, username, password, admin, nick, realname, enabled,
- downstream_interacted_at
+ downstream_interacted_at, max_networks
FROM User`)
if err != nil {
return nil, err
@@ -354,7 +355,7 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
var user User
var password, nick, realname sql.NullString
var downstreamInteractedAt sqliteTime
- if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
+ if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil {
return nil, err
}
user.Password = password.String
@@ -380,11 +381,11 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
var downstreamInteractedAt sqliteTime
row := db.db.QueryRowContext(ctx,
`SELECT id, password, admin, nick, realname, enabled,
- downstream_interacted_at
+ downstream_interacted_at, max_networks
FROM User
WHERE username = ?`,
username)
- if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
+ if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt, &user.MaxNetworks); err != nil {
return nil, err
}
user.Password = password.String
@@ -434,6 +435,7 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
sql.Named("enabled", user.Enabled),
sql.Named("now", sqliteTime{time.Now()}),
sql.Named("downstream_interacted_at", sqliteTime{user.DownstreamInteractedAt}),
+ sql.Named("max_networks", user.MaxNetworks),
}
var err error
@@ -442,7 +444,8 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
UPDATE User
SET password = :password, admin = :admin, nick = :nick,
realname = :realname, enabled = :enabled,
- downstream_interacted_at = :downstream_interacted_at
+ downstream_interacted_at = :downstream_interacted_at,
+ max_networks = :max_networks
WHERE username = :username`,
args...)
} else {
@@ -450,9 +453,9 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
res, err = db.db.ExecContext(ctx, `
INSERT INTO
User(username, password, admin, nick, realname, created_at,
- enabled, downstream_interacted_at)
+ enabled, downstream_interacted_at, max_networks)
VALUES (:username, :password, :admin, :nick, :realname, :now,
- :enabled, :downstream_interacted_at)`,
+ :enabled, :downstream_interacted_at, :max_networks)`,
args...)
if err != nil {
return err
diff --git a/database/sqlite_schema.sql b/database/sqlite_schema.sql
index b8445d1..fc10578 100644
--- a/database/sqlite_schema.sql
+++ b/database/sqlite_schema.sql
@@ -7,7 +7,8 @@ CREATE TABLE User (
nick TEXT,
created_at TEXT NOT NULL,
enabled INTEGER NOT NULL DEFAULT 1,
- downstream_interacted_at TEXT
+ downstream_interacted_at TEXT,
+ max_networks INTEGER NOT NULL DEFAULT -1
);
CREATE TABLE Network (
diff --git a/doc/soju.1.scd b/doc/soju.1.scd
index 9ee7824..5bc9f32 100644
--- a/doc/soju.1.scd
+++ b/doc/soju.1.scd
@@ -519,6 +519,11 @@ character.
not connect to any of their networks, and downstream connections will
be immediately closed. By default, users are enabled.
+ *-max-networks* <max-networks>
+ Set a limit on the number of enabled networks this user can use. A limit
+ of 0 means no network, and -1 means to default to the global
+ _max-user-networks_ configuration value.
+
*user update* [username] [options...]
Update a user. The options are the same as the _user create_ command.
@@ -530,8 +535,8 @@ character.
- The _-username_ flag is never valid, usernames are immutable.
- The _-nick_ and _-realname_ flag are only valid when updating the current
user.
- - The _-admin_ and _-enabled_ flags are only valid when updating another
- user.
+ - The _-admin_, _-enabled_ and _-max-networks_ flags are only valid when
+ updating another user.
*user delete* <username> [confirmation token]
Delete a soju user.
diff --git a/service.go b/service.go
index df30bb8..3b50055 100644
--- a/service.go
+++ b/service.go
@@ -286,14 +286,14 @@ func init() {
global: true,
},
"create": {
- usage: "-username <username> -password <password> [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false]",
+ usage: "-username <username> -password <password> [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false] [-max-networks <max-networks>]",
desc: "create a new soju user",
handle: handleUserCreate,
admin: true,
global: true,
},
"update": {
- usage: "[username] [-password <password>] [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false]",
+ usage: "[username] [-password <password>] [-disable-password] [-admin true|false] [-nick <nick>] [-realname <realname>] [-enabled true|false] [-max-networks <max-networks>]",
desc: "update a user",
handle: handleUserUpdate,
global: true,
@@ -458,6 +458,26 @@ func (f boolPtrFlag) Set(s string) error {
return nil
}
+type intPtrFlag struct {
+ ptr **int
+}
+
+func (f intPtrFlag) String() string {
+ if f.ptr == nil || *f.ptr == nil {
+ return "<nil>"
+ }
+ return strconv.Itoa(**f.ptr)
+}
+
+func (f intPtrFlag) Set(s string) error {
+ v, err := strconv.Atoi(s)
+ if err != nil {
+ return err
+ }
+ *f.ptr = &v
+ return nil
+}
+
func getNetworkFromArg(ctx *serviceContext, params []string) (*network, []string, error) {
name, params := popArg(params)
if name == "" {
@@ -951,6 +971,9 @@ func handleUserStatus(ctx *serviceContext, params []string) error {
if !user.Enabled {
attrs = append(attrs, "disabled")
}
+ if user.MaxNetworks >= 0 {
+ attrs = append(attrs, fmt.Sprintf("max %d networks", user.MaxNetworks))
+ }
line := user.Username
if len(attrs) > 0 {
@@ -979,6 +1002,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error {
realname := fs.String("realname", "", "")
admin := fs.Bool("admin", false, "")
enabled := fs.Bool("enabled", true, "")
+ maxNetworks := fs.Int("max-networks", -1, "")
if err := fs.Parse(params); err != nil {
return err
@@ -1001,6 +1025,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error {
user.Realname = *realname
user.Admin = *admin
user.Enabled = *enabled
+ user.MaxNetworks = *maxNetworks
if !*disablePassword {
if err := user.SetPassword(*password); err != nil {
return err
@@ -1025,6 +1050,7 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
var password, nick, realname *string
var admin, enabled *bool
var disablePassword bool
+ var maxNetworks *int
fs := newFlagSet()
fs.Var(stringPtrFlag{&password}, "password", "")
fs.BoolVar(&disablePassword, "disable-password", false, "")
@@ -1032,6 +1058,7 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
fs.Var(stringPtrFlag{&realname}, "realname", "")
fs.Var(boolPtrFlag{&admin}, "admin", "")
fs.Var(boolPtrFlag{&enabled}, "enabled", "")
+ fs.Var(intPtrFlag{&maxNetworks}, "max-networks", "")
username, params := popArg(params)
if err := fs.Parse(params); err != nil {
@@ -1079,10 +1106,11 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
done := make(chan error, 1)
event := eventUserUpdate{
- password: hashed,
- admin: admin,
- enabled: enabled,
- done: done,
+ password: hashed,
+ admin: admin,
+ enabled: enabled,
+ maxNetworks: maxNetworks,
+ done: done,
}
select {
case <-ctx.Done():
@@ -1102,6 +1130,9 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
if enabled != nil {
return fmt.Errorf("cannot update -enabled of own user")
}
+ if maxNetworks != nil {
+ return fmt.Errorf("cannot update -max-networks of own user")
+ }
err := ctx.user.updateUser(ctx, func(record *database.User) error {
if password != nil {
diff --git a/user.go b/user.go
index cbe6f5c..0ad6dba 100644
--- a/user.go
+++ b/user.go
@@ -75,10 +75,11 @@ type eventBroadcast struct {
type eventStop struct{}
type eventUserUpdate struct {
- password *string
- admin *bool
- enabled *bool
- done chan error
+ password *string
+ admin *bool
+ enabled *bool
+ maxNetworks *int
+ done chan error
}
type eventTryRegainNick struct {
@@ -821,6 +822,9 @@ func (u *user) run() {
if e.enabled != nil {
record.Enabled = *e.enabled
}
+ if e.maxNetworks != nil {
+ record.MaxNetworks = *e.maxNetworks
+ }
return nil
})
@@ -966,7 +970,10 @@ func (u *user) canEnableNewNetwork(ctx context.Context) error {
if u.Admin {
return nil
}
- max := u.srv.Config().MaxUserNetworks
+ max := u.MaxNetworks
+ if max < 0 {
+ max = u.srv.Config().MaxUserNetworks
+ }
if max < 0 {
return nil
}
--
2.38.0