This commit adds a new network field named WaitUntilMode. This field
is consulted when handling the RPL_WELCOME numeric and MODE messages. If
WaitUntilMode is not empty, soju doesn't join channels on RPL_WELCOME.
Instead, this field is checked on every MODE message that changes our own
user's modes. Once the user's modes contain all modes from WaitUntilMode,
soju proceeds to join the channels.
For example, users of networks that do not support SASL (or that otherwise
authenticate the user only after connecting) can set WaitUntilMode to
"R" (registered) so that soju can still join channels which only admit
registered users.
---
Feel free to NACK this patch. I couldn't come up with a better solution,
other than something like a -wait-until-server-sends-message-that-matches
".*You are successfully identified as .*" option, which is much messier.
Also, of the IRC servers I lurk on, only libera.chat (Solanum) doesn't set
any registration-related user mode. OFTC (oftc-hybrid, a fork of
ircd-hybrid), Rizon (plexus, also a fork of ircd-hybrid) and tilde.chat
(InspIRCd) do.
db.go | 1 +
db_postgres.go | 22 +++++++++------
db_sqlite.go | 20 ++++++++++----
doc/soju.1.scd | 6 ++++
irc.go | 10 +++++++
service.go | 23 ++++++----------
upstream.go | 75 +++++++++++++++++++++++++++++++++-----------------
user.go | 19 +++++++++++++
8 files changed, 123 insertions(+), 53 deletions(-)
diff --git a/db.go b/db.go
index 6dcab77..f202609 100644
--- a/db.go
+++ b/db.go
@@ -88,6 +88,7 @@ type Network struct {
Username string
Realname string
Pass string
+ WaitUntilMode userModes
ConnectCommands []string
SASL SASL
Enabled bool
diff --git a/db_postgres.go b/db_postgres.go
index 8833adf..15615b6 100644
--- a/db_postgres.go
+++ b/db_postgres.go
@@ -44,6 +44,7 @@ CREATE TABLE "Network" (
username VARCHAR(255),
realname VARCHAR(255),
pass VARCHAR(255),
+ wait_until_mode VARCHAR(31),
connect_commands VARCHAR(1023),
sasl_mechanism sasl_mechanism,
sasl_plain_username VARCHAR(255),
@@ -106,6 +107,7 @@ var postgresMigrations = []string{
UNIQUE(network, target)
);
`,
+ `ALTER TABLE "Network" ADD COLUMN wait_until_mode VARCHAR(31)`,
}
type PostgresDB struct {
@@ -284,7 +286,8 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network
rows, err := db.db.QueryContext(ctx, `
SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
- sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
+ sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key,
+ enabled, wait_until_mode
FROM "Network"
WHERE "user" = $1`, userID)
if err != nil {
@@ -295,11 +298,13 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network
var networks []Network
for rows.Next() {
var net Network
- var name, nick, username, realname, pass, connectCommands sql.NullString
+ var name, nick, username, realname, pass, waitUntilMode,
+ connectCommands sql.NullString
var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
&pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
- &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
+ &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled,
+ &waitUntilMode)
if err != nil {
return nil, err
}
@@ -308,6 +313,7 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network
net.Username = username.String
net.Realname = realname.String
net.Pass = pass.String
+ net.WaitUntilMode = userModes(waitUntilMode.String)
if connectCommands.Valid {
net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
}
@@ -355,23 +361,23 @@ func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *N
err = db.db.QueryRowContext(ctx, `
INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
- sasl_external_key, enabled)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
+ sasl_external_key, enabled, wait_until_mode)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
RETURNING id`,
userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
- network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
+ network.SASL.External.PrivKeyBlob, network.Enabled, network.WaitUntilMode).Scan(&network.ID)
} else {
_, err = db.db.ExecContext(ctx, `
UPDATE "Network"
SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
- enabled = $14
+ enabled = $14, wait_until_mode = $15
WHERE id = $1`,
network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
- network.SASL.External.PrivKeyBlob, network.Enabled)
+ network.SASL.External.PrivKeyBlob, network.Enabled, network.WaitUntilMode)
}
return err
}
diff --git a/db_sqlite.go b/db_sqlite.go
index 89c9478..44bb9ed 100644
--- a/db_sqlite.go
+++ b/db_sqlite.go
@@ -34,6 +34,7 @@ CREATE TABLE Network (
username TEXT,
realname TEXT,
pass TEXT,
+ wait_until_mode TEXT,
connect_commands TEXT,
sasl_mechanism TEXT,
sasl_plain_username TEXT,
@@ -189,6 +190,7 @@ var sqliteMigrations = []string{
UNIQUE(network, target)
);
`,
+ "ALTER TABLE Network ADD COLUMN wait_until_mode TEXT",
}
type SqliteDB struct {
@@ -447,7 +449,7 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network,
rows, err := db.db.QueryContext(ctx, `
SELECT id, name, addr, nick, username, realname, pass,
connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
- sasl_external_cert, sasl_external_key, enabled
+ sasl_external_cert, sasl_external_key, enabled, wait_until_mode
FROM Network
WHERE user = ?`,
userID)
@@ -459,11 +461,13 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network,
var networks []Network
for rows.Next() {
var net Network
- var name, nick, username, realname, pass, connectCommands sql.NullString
+ var name, nick, username, realname, pass, waitUntilMode,
+ connectCommands sql.NullString
var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
&pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
- &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
+ &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled,
+ &waitUntilMode)
if err != nil {
return nil, err
}
@@ -472,6 +476,7 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network,
net.Username = username.String
net.Realname = realname.String
net.Pass = pass.String
+ net.WaitUntilMode = userModes(waitUntilMode.String)
if connectCommands.Valid {
net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
}
@@ -524,6 +529,7 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net
sql.Named("sasl_external_cert", network.SASL.External.CertBlob),
sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob),
sql.Named("enabled", network.Enabled),
+ sql.Named("wait_until_mode", network.WaitUntilMode),
sql.Named("id", network.ID), // only for UPDATE
sql.Named("user", userID), // only for INSERT
@@ -537,17 +543,19 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net
realname = :realname, pass = :pass, connect_commands = :connect_commands,
sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password,
sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key,
- enabled = :enabled
+ enabled = :enabled, wait_until_mode = :wait_until_mode
WHERE id = :id`, args...)
} else {
var res sql.Result
res, err = db.db.ExecContext(ctx, `
INSERT INTO Network(user, name, addr, nick, username, realname, pass,
connect_commands, sasl_mechanism, sasl_plain_username,
- sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
+ sasl_plain_password, sasl_external_cert, sasl_external_key,
+ enabled, wait_until_mode)
VALUES (:user, :name, :addr, :nick, :username, :realname, :pass,
:connect_commands, :sasl_mechanism, :sasl_plain_username,
- :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`,
+ :sasl_plain_password, :sasl_external_cert, :sasl_external_key,
+ :enabled, :wait_until_mode)`,
args...)
if err != nil {
return err
diff --git a/doc/soju.1.scd b/doc/soju.1.scd
index 68b19fe..7e680ca 100644
--- a/doc/soju.1.scd
+++ b/doc/soju.1.scd
@@ -224,6 +224,12 @@ abbreviated form, for instance *network* can be abbreviated as *net* or just
Enable or disable the network. If the network is disabled, the bouncer
won't connect to it. By default, the network is enabled.
+ *-wait-until-mode* <user-modes>
+ Wait until all of the given user modes are set before joining the
+ network's channels. This can be used together with -connect-command on
+ servers that don't support SASL, to avoid being kicked out of channels
+ that only allow registered users to join.
+
*-connect-command* <command>
Send the specified command as a raw IRC message right after connecting
to the server. This can be used to identify to an account when the
diff --git a/irc.go b/irc.go
index 9f417a0..6bd760c 100644
--- a/irc.go
+++ b/irc.go
@@ -39,6 +39,16 @@ func formatServerTime(t time.Time) string {
return t.UTC().Format(serverTimeLayout)
}
+// isValidMode reports whether mode is an ASCII Latin letter (A-Z or a-z).
+// checkNetwork uses it to validate each character of wait-until-mode.
+func isValidMode(mode byte) bool {
+	upper := 'A' <= mode && mode <= 'Z'
+	lower := 'a' <= mode && mode <= 'z'
+	return upper || lower
+}
+
type userModes string
func (ms userModes) Has(c byte) bool {
diff --git a/service.go b/service.go
index 2a94ee9..95986ca 100644
--- a/service.go
+++ b/service.go
@@ -199,7 +199,7 @@ func init() {
"network": {
children: serviceCommandSet{
"create": {
- usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
+ usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-wait-until-mode usermodes] [-connect-command command]...",
desc: "add a new network",
handle: handleServiceNetworkCreate,
},
@@ -208,7 +208,7 @@ func init() {
handle: handleServiceNetworkStatus,
},
"update": {
- usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
+ usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-wait-until-mode usermodes] [-connect-command command]...",
desc: "update a network",
handle: handleServiceNetworkUpdate,
},
@@ -429,9 +429,9 @@ func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string,
type networkFlagSet struct {
*flag.FlagSet
- Addr, Name, Nick, Username, Pass, Realname *string
- Enabled *bool
- ConnectCommands []string
+ Addr, Name, Nick, Username, Pass, Realname, WaitUntilMode *string
+ Enabled *bool
+ ConnectCommands []string
}
func newNetworkFlagSet() *networkFlagSet {
@@ -443,6 +443,7 @@ func newNetworkFlagSet() *networkFlagSet {
fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "")
+ fs.Var(stringPtrFlag{&fs.WaitUntilMode}, "wait-until-mode", "")
fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
return fs
}
@@ -480,19 +481,13 @@ func (fs *networkFlagSet) update(network *Network) error {
if fs.Enabled != nil {
network.Enabled = *fs.Enabled
}
+ if fs.WaitUntilMode != nil {
+ network.WaitUntilMode = userModes(*fs.WaitUntilMode)
+ }
if fs.ConnectCommands != nil {
if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
network.ConnectCommands = nil
} else {
- if len(fs.ConnectCommands) > 20 {
- return fmt.Errorf("too many -connect-command flags supplied")
- }
- for _, command := range fs.ConnectCommands {
- _, err := irc.ParseMessage(command)
- if err != nil {
- return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
- }
- }
network.ConnectCommands = fs.ConnectCommands
}
}
diff --git a/upstream.go b/upstream.go
index 56d604e..c52c238 100644
--- a/upstream.go
+++ b/upstream.go
@@ -119,20 +119,21 @@ type upstreamConn struct {
availableMemberships []membership
isupport map[string]*string
- registered bool
- nick string
- nickCM string
- username string
- realname string
- hostname string
- modes userModes
- channels upstreamChannelCasemapMap
- caps capRegistry
- batches map[string]batch
- away bool
- account string
- nextLabelID uint64
- monitored monitorCasemapMap
+ registered bool
+ joinedChannels bool
+ nick string
+ nickCM string
+ username string
+ realname string
+ hostname string
+ modes userModes
+ channels upstreamChannelCasemapMap
+ caps capRegistry
+ batches map[string]batch
+ away bool
+ account string
+ nextLabelID uint64
+ monitored monitorCasemapMap
saslClient sasl.Client
saslStarted bool
@@ -744,17 +745,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.nickCM = uc.network.casemap(uc.nick)
uc.logger.Printf("connection registered with nick %q", uc.nick)
- if uc.network.channels.Len() > 0 {
- var channels, keys []string
- for _, entry := range uc.network.channels.innerMap {
- ch := entry.value.(*Channel)
- channels = append(channels, ch.Name)
- keys = append(keys, ch.Key)
- }
-
- for _, msg := range join(channels, keys) {
- uc.SendMessage(ctx, msg)
- }
+ if len([]byte(uc.network.WaitUntilMode)) == 0 {
+ uc.joinChannels(ctx)
}
case irc.RPL_MYINFO:
if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
@@ -1123,6 +1115,20 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
dc.SendMessage(msg)
})
+
+ if len([]byte(uc.network.WaitUntilMode)) > 0 {
+ matches := true
+ for _, mode := range []byte(uc.network.WaitUntilMode) {
+ if !uc.modes.Has(mode) {
+ matches = false
+ break
+ }
+ }
+
+ if matches {
+ uc.joinChannels(ctx)
+ }
+ }
} else { // channel mode change
ch, err := uc.getChannel(name)
if err != nil {
@@ -1946,6 +1952,25 @@ func splitSpace(s string) []string {
})
}
+func (uc *upstreamConn) joinChannels(ctx context.Context) {
+ if uc.joinedChannels {
+ return
+ }
+ if uc.network.channels.Len() > 0 {
+ var channels, keys []string
+ for _, entry := range uc.network.channels.innerMap {
+ ch := entry.value.(*Channel)
+ channels = append(channels, ch.Name)
+ keys = append(keys, ch.Key)
+ }
+
+ for _, msg := range join(channels, keys) {
+ uc.SendMessage(ctx, msg)
+ }
+ }
+ uc.joinedChannels = true
+}
+
func (uc *upstreamConn) register(ctx context.Context) {
uc.nick = GetNick(&uc.user.User, &uc.network.Network)
uc.nickCM = uc.network.casemap(uc.nick)
diff --git a/user.go b/user.go
index bf1909f..2e0ecce 100644
--- a/user.go
+++ b/user.go
@@ -857,6 +857,25 @@ func (u *user) checkNetwork(record *Network) error {
}
}
+ if record.WaitUntilMode != "" {
+ for _, mode := range []byte(record.WaitUntilMode) {
+ if !isValidMode(mode) {
+ return fmt.Errorf("user modes specified in wait-until-mode must be Latin characters (a-zA-Z)")
+ }
+ }
+ }
+
+ if len(record.ConnectCommands) > 20 {
+ return fmt.Errorf("too many connect commands supplied")
+ } else if len(record.ConnectCommands) > 0 {
+ for _, command := range record.ConnectCommands {
+ _, err := irc.ParseMessage(command)
+ if err != nil {
+ return fmt.Errorf("connect commands must be valid raw irc command strings: %q: %v", command, err)
+ }
+ }
+ }
+
return nil
}
--
2.35.1