~emersion/soju-dev

Allow waiting for usermode before joining channels v1 PROPOSED

Umar Getagazov: 1
 Allow waiting for usermode before joining channels

 8 files changed, 123 insertions(+), 53 deletions(-)
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~emersion/soju-dev/patches/30466/mbox | git am -3
Learn more about email & git

[PATCH] Allow waiting for usermode before joining channels Export this patch

This commit adds a new field to networks named WaitUntilMode. This field
is consulted on RPL_WELCOME numeric message and MODE message. If
WaitUntilMode is not empty, soju doesn't join channels on RPL_WELCOME.
Instead, on any MODE message that touches our user, this field is
consulted. If user's modes contain all modes from WaitUntilMode, soju
proceeds to join the channels.

As an example, users of networks that do not support SASL (or somehow
otherwise authenticate the user only after connecting) can set
WaitUntilMode to "R" (registered) and still allow soju to join channels
that only allow registered users to join.
---
Feel free to NACK this patch. I couldn't come up with a better solution,
except for a -wait-until-server-sends-message-that-matches
".*You are successfully identified as .*" which is much messier.

Also, of IRC servers I lurk on, only libera.chat (Solanum) doesn't set
any registration-related flags on the user. OFTC (oftc-hybrid, fork of
ircd-hybrid) (duh), Rizon (plexus, fork of ircd-hybrid) and tilde.chat
(InspIRCd) do.

 db.go          |  1 +
 db_postgres.go | 22 +++++++++------
 db_sqlite.go   | 20 ++++++++++----
 doc/soju.1.scd |  6 ++++
 irc.go         | 10 +++++++
 service.go     | 23 ++++++----------
 upstream.go    | 75 +++++++++++++++++++++++++++++++++-----------------
 user.go        | 19 +++++++++++++
 8 files changed, 123 insertions(+), 53 deletions(-)

diff --git a/db.go b/db.go
index 6dcab77..f202609 100644
--- a/db.go
+++ b/db.go
@@ -88,6 +88,7 @@ type Network struct {
	Username        string
	Realname        string
	Pass            string
	WaitUntilMode   userModes
	ConnectCommands []string
	SASL            SASL
	Enabled         bool
diff --git a/db_postgres.go b/db_postgres.go
index 8833adf..15615b6 100644
--- a/db_postgres.go
+++ b/db_postgres.go
@@ -44,6 +44,7 @@ CREATE TABLE "Network" (
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	wait_until_mode VARCHAR(31),
	connect_commands VARCHAR(1023),
	sasl_mechanism sasl_mechanism,
	sasl_plain_username VARCHAR(255),
@@ -106,6 +107,7 @@ var postgresMigrations = []string{
			UNIQUE(network, target)
		);
	`,
	`ALTER TABLE "Network" ADD COLUMN wait_until_mode VARCHAR(31)`,
}

type PostgresDB struct {
@@ -284,7 +286,8 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network

	rows, err := db.db.QueryContext(ctx, `
		SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
			sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
			sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key,
			enabled, wait_until_mode
		FROM "Network"
		WHERE "user" = $1`, userID)
	if err != nil {
@@ -295,11 +298,13 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network
	var networks []Network
	for rows.Next() {
		var net Network
		var name, nick, username, realname, pass, connectCommands sql.NullString
		var name, nick, username, realname, pass, waitUntilMode,
			connectCommands sql.NullString
		var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
		err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
			&pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
			&net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
			&net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled,
			&waitUntilMode)
		if err != nil {
			return nil, err
		}
@@ -308,6 +313,7 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network
		net.Username = username.String
		net.Realname = realname.String
		net.Pass = pass.String
		net.WaitUntilMode = userModes(waitUntilMode.String)
		if connectCommands.Valid {
			net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
		}
@@ -355,23 +361,23 @@ func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *N
		err = db.db.QueryRowContext(ctx, `
			INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
				sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
				sasl_external_key, enabled)
			VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
				sasl_external_key, enabled, wait_until_mode)
			VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
			RETURNING id`,
			userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
			network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
			network.SASL.External.PrivKeyBlob, network.Enabled, network.WaitUntilMode).Scan(&network.ID)
	} else {
		_, err = db.db.ExecContext(ctx, `
			UPDATE "Network"
			SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
				connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
				sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
				enabled = $14
				enabled = $14, wait_until_mode = $15
			WHERE id = $1`,
			network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
			network.SASL.External.PrivKeyBlob, network.Enabled)
			network.SASL.External.PrivKeyBlob, network.Enabled, network.WaitUntilMode)
	}
	return err
}
diff --git a/db_sqlite.go b/db_sqlite.go
index 89c9478..44bb9ed 100644
--- a/db_sqlite.go
+++ b/db_sqlite.go
@@ -34,6 +34,7 @@ CREATE TABLE Network (
	username TEXT,
	realname TEXT,
	pass TEXT,
	wait_until_mode TEXT,
	connect_commands TEXT,
	sasl_mechanism TEXT,
	sasl_plain_username TEXT,
@@ -189,6 +190,7 @@ var sqliteMigrations = []string{
			UNIQUE(network, target)
		);
	`,
	"ALTER TABLE Network ADD COLUMN wait_until_mode TEXT",
}

type SqliteDB struct {
@@ -447,7 +449,7 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network,
	rows, err := db.db.QueryContext(ctx, `
		SELECT id, name, addr, nick, username, realname, pass,
			connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
			sasl_external_cert, sasl_external_key, enabled
			sasl_external_cert, sasl_external_key, enabled, wait_until_mode
		FROM Network
		WHERE user = ?`,
		userID)
@@ -459,11 +461,13 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network,
	var networks []Network
	for rows.Next() {
		var net Network
		var name, nick, username, realname, pass, connectCommands sql.NullString
		var name, nick, username, realname, pass, waitUntilMode,
			connectCommands sql.NullString
		var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
		err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
			&pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
			&net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
			&net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled,
			&waitUntilMode)
		if err != nil {
			return nil, err
		}
@@ -472,6 +476,7 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network,
		net.Username = username.String
		net.Realname = realname.String
		net.Pass = pass.String
		net.WaitUntilMode = userModes(waitUntilMode.String)
		if connectCommands.Valid {
			net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
		}
@@ -524,6 +529,7 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net
		sql.Named("sasl_external_cert", network.SASL.External.CertBlob),
		sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob),
		sql.Named("enabled", network.Enabled),
		sql.Named("wait_until_mode", network.WaitUntilMode),

		sql.Named("id", network.ID), // only for UPDATE
		sql.Named("user", userID),   // only for INSERT
@@ -537,17 +543,19 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net
				realname = :realname, pass = :pass, connect_commands = :connect_commands,
				sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password,
				sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key,
				enabled = :enabled
				enabled = :enabled, wait_until_mode = :wait_until_mode
			WHERE id = :id`, args...)
	} else {
		var res sql.Result
		res, err = db.db.ExecContext(ctx, `
			INSERT INTO Network(user, name, addr, nick, username, realname, pass,
				connect_commands, sasl_mechanism, sasl_plain_username,
				sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
				sasl_plain_password, sasl_external_cert, sasl_external_key,
				enabled, wait_until_mode)
			VALUES (:user, :name, :addr, :nick, :username, :realname, :pass,
				:connect_commands, :sasl_mechanism, :sasl_plain_username,
				:sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`,
				:sasl_plain_password, :sasl_external_cert, :sasl_external_key,
				:enabled, :wait_until_mode)`,
			args...)
		if err != nil {
			return err
diff --git a/doc/soju.1.scd b/doc/soju.1.scd
index 68b19fe..7e680ca 100644
--- a/doc/soju.1.scd
+++ b/doc/soju.1.scd
@@ -224,6 +224,12 @@ abbreviated form, for instance *network* can be abbreviated as *net* or just
		Enable or disable the network. If the network is disabled, the bouncer
		won't connect to it. By default, the network is enabled.

	*-wait-until-mode* <user-modes>
		Wait for the user modes to be set until joining network's channels.
		This can be used together with -connect-command to not get kicked out
		of channels on servers that don't support SASL which allow only
		registered users to join.

	*-connect-command* <command>
		Send the specified command as a raw IRC message right after connecting
		to the server. This can be used to identify to an account when the
diff --git a/irc.go b/irc.go
index 9f417a0..6bd760c 100644
--- a/irc.go
+++ b/irc.go
@@ -39,6 +39,16 @@ func formatServerTime(t time.Time) string {
	return t.UTC().Format(serverTimeLayout)
}

func isValidMode(mode byte) bool {
	if mode >= byte('A') && mode <= byte('Z') {
		return true
	}
	if mode >= byte('a') && mode <= byte('z') {
		return true
	}
	return false
}

type userModes string

func (ms userModes) Has(c byte) bool {
diff --git a/service.go b/service.go
index 2a94ee9..95986ca 100644
--- a/service.go
+++ b/service.go
@@ -199,7 +199,7 @@ func init() {
		"network": {
			children: serviceCommandSet{
				"create": {
					usage:  "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
					usage:  "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-wait-until-mode usermodes] [-connect-command command]...",
					desc:   "add a new network",
					handle: handleServiceNetworkCreate,
				},
@@ -208,7 +208,7 @@ func init() {
					handle: handleServiceNetworkStatus,
				},
				"update": {
					usage:  "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
					usage:  "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-wait-until-mode usermodes] [-connect-command command]...",
					desc:   "update a network",
					handle: handleServiceNetworkUpdate,
				},
@@ -429,9 +429,9 @@ func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string,

type networkFlagSet struct {
	*flag.FlagSet
	Addr, Name, Nick, Username, Pass, Realname *string
	Enabled                                    *bool
	ConnectCommands                            []string
	Addr, Name, Nick, Username, Pass, Realname, WaitUntilMode *string
	Enabled                                                   *bool
	ConnectCommands                                           []string
}

func newNetworkFlagSet() *networkFlagSet {
@@ -443,6 +443,7 @@ func newNetworkFlagSet() *networkFlagSet {
	fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
	fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
	fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "")
	fs.Var(stringPtrFlag{&fs.WaitUntilMode}, "wait-until-mode", "")
	fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
	return fs
}
@@ -480,19 +481,13 @@ func (fs *networkFlagSet) update(network *Network) error {
	if fs.Enabled != nil {
		network.Enabled = *fs.Enabled
	}
	if fs.WaitUntilMode != nil {
		network.WaitUntilMode = userModes(*fs.WaitUntilMode)
	}
	if fs.ConnectCommands != nil {
		if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
			network.ConnectCommands = nil
		} else {
			if len(fs.ConnectCommands) > 20 {
				return fmt.Errorf("too many -connect-command flags supplied")
			}
			for _, command := range fs.ConnectCommands {
				_, err := irc.ParseMessage(command)
				if err != nil {
					return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
				}
			}
			network.ConnectCommands = fs.ConnectCommands
		}
	}
diff --git a/upstream.go b/upstream.go
index 56d604e..c52c238 100644
--- a/upstream.go
+++ b/upstream.go
@@ -119,20 +119,21 @@ type upstreamConn struct {
	availableMemberships  []membership
	isupport              map[string]*string

	registered  bool
	nick        string
	nickCM      string
	username    string
	realname    string
	hostname    string
	modes       userModes
	channels    upstreamChannelCasemapMap
	caps        capRegistry
	batches     map[string]batch
	away        bool
	account     string
	nextLabelID uint64
	monitored   monitorCasemapMap
	registered     bool
	joinedChannels bool
	nick           string
	nickCM         string
	username       string
	realname       string
	hostname       string
	modes          userModes
	channels       upstreamChannelCasemapMap
	caps           capRegistry
	batches        map[string]batch
	away           bool
	account        string
	nextLabelID    uint64
	monitored      monitorCasemapMap

	saslClient  sasl.Client
	saslStarted bool
@@ -744,17 +745,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
		uc.nickCM = uc.network.casemap(uc.nick)
		uc.logger.Printf("connection registered with nick %q", uc.nick)

		if uc.network.channels.Len() > 0 {
			var channels, keys []string
			for _, entry := range uc.network.channels.innerMap {
				ch := entry.value.(*Channel)
				channels = append(channels, ch.Name)
				keys = append(keys, ch.Key)
			}

			for _, msg := range join(channels, keys) {
				uc.SendMessage(ctx, msg)
			}
		if len([]byte(uc.network.WaitUntilMode)) == 0 {
			uc.joinChannels(ctx)
		}
	case irc.RPL_MYINFO:
		if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
@@ -1123,6 +1115,20 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err

				dc.SendMessage(msg)
			})

			if len([]byte(uc.network.WaitUntilMode)) > 0 {
				matches := true
				for _, mode := range []byte(uc.network.WaitUntilMode) {
					if !uc.modes.Has(mode) {
						matches = false
						break
					}
				}

				if matches {
					uc.joinChannels(ctx)
				}
			}
		} else { // channel mode change
			ch, err := uc.getChannel(name)
			if err != nil {
@@ -1946,6 +1952,25 @@ func splitSpace(s string) []string {
	})
}

func (uc *upstreamConn) joinChannels(ctx context.Context) {
	if uc.joinedChannels {
		return
	}
	if uc.network.channels.Len() > 0 {
		var channels, keys []string
		for _, entry := range uc.network.channels.innerMap {
			ch := entry.value.(*Channel)
			channels = append(channels, ch.Name)
			keys = append(keys, ch.Key)
		}

		for _, msg := range join(channels, keys) {
			uc.SendMessage(ctx, msg)
		}
	}
	uc.joinedChannels = true
}

func (uc *upstreamConn) register(ctx context.Context) {
	uc.nick = GetNick(&uc.user.User, &uc.network.Network)
	uc.nickCM = uc.network.casemap(uc.nick)
diff --git a/user.go b/user.go
index bf1909f..2e0ecce 100644
--- a/user.go
+++ b/user.go
@@ -857,6 +857,25 @@ func (u *user) checkNetwork(record *Network) error {
		}
	}

	if record.WaitUntilMode != "" {
		for _, mode := range []byte(record.WaitUntilMode) {
			if !isValidMode(mode) {
				return fmt.Errorf("user modes specified in wait-until-mode must be Latin characters (a-zA-Z)")
			}
		}
	}

	if len(record.ConnectCommands) > 20 {
		return fmt.Errorf("too many connect commands supplied")
	} else if len(record.ConnectCommands) > 0 {
		for _, command := range record.ConnectCommands {
			_, err := irc.ParseMessage(command)
			if err != nil {
				return fmt.Errorf("connect commands must be valid raw irc command strings: %q: %v", command, err)
			}
		}
	}

	return nil
}

-- 
2.35.1
This sounds like a potentially acceptable workaround, but I'd like to
make sure this isn't only useful for OFTC. Do other non-SASL servers
in widespread use also set a user mode when NickServ kicks in?