~emersion/soju-dev

Add support for the upstream echo-message capability v2 SUPERSEDED

delthas: 1
 Add support for the upstream echo-message capability

 3 files changed, 96 insertions(+), 83 deletions(-)
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~emersion/soju-dev/patches/29238/mbox | git am -3
Learn more about email & git

[PATCH v2] Add support for the upstream echo-message capability Export this patch

This adds support for upstream echo-message. This capability is only
enabled when all downstreams support it.

When it is enabled, we don't echo downstream messages in the downstream
handler, but rather wait for the upstream to echo it, to produce it to
all downstreams.

When it is disabled, we keep the same behaviour as before: produce the
message to all downstreams as soon as it is received from the
downstream.

In other words, the main functional difference is that when all
upstreams support echo-message, the client will now receive an echo for
its messages when the server acknowledges them, rather than when soju
acks them.

Additionally, the downstream PRIVMSG/NOTICE/TAGMSG handler was slightly
refactored into a common switch case as there was starting to be a lot
of common code.
---
 downstream.go | 139 ++++++++++++++++++++++----------------------------
 upstream.go   |  38 +++++++++++---
 user.go       |   2 +
 3 files changed, 96 insertions(+), 83 deletions(-)

diff --git a/downstream.go b/downstream.go
index b5caa8c..fb06bcd 100644
--- a/downstream.go
+++ b/downstream.go
@@ -909,6 +909,10 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {

		if !dc.registered {
			dc.negotiatingCaps = true
		} else {
			dc.forEachUpstream(func(uc *upstreamConn) {
				uc.updateCaps()
			})
		}
	case "END":
		dc.negotiatingCaps = false
@@ -2304,15 +2308,25 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
			Command: "WHOIS",
			Params:  params,
		})
	case "PRIVMSG", "NOTICE":
	case "PRIVMSG", "NOTICE", "TAGMSG":
		tag := msg.Command == "TAGMSG"
		var targetsStr, text string
		if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
		if err := parseMessageParams(msg, &targetsStr); err != nil {
			return err
		}
		if !tag {
			if err := parseMessageParams(msg, nil, &text); err != nil {
				return err
			}
		}
		tags := copyClientTags(msg.Tags)

		for _, name := range strings.Split(targetsStr, ",") {
			if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) {
			msgParams := []string{name}
			if !tag {
				msgParams = append(msgParams, text)
			}
			if !tag && (name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil)) {
				// "$" means a server mask follows. If it's the bouncer's
				// hostname, broadcast the message to all bouncer users.
				if !dc.user.Admin {
@@ -2331,7 +2345,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
					Tags:    broadcastTags,
					Prefix:  servicePrefix,
					Command: msg.Command,
					Params:  []string{name, text},
					Params:  msgParams,
				}
				dc.srv.forEachUser(func(u *user) {
					u.events <- eventBroadcast{broadcastMsg}
@@ -2344,12 +2358,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
					Tags:    msg.Tags.Copy(),
					Prefix:  dc.prefix(),
					Command: msg.Command,
					Params:  []string{name, text},
					Params:  msgParams,
				})
				continue
			}

			if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM {
			if msg.Command != "NOTICE" && casemapASCII(name) == serviceNickCM {
				if dc.caps["echo-message"] {
					echoTags := tags.Copy()
					echoTags["time"] = irc.TagValue(time.Now().UTC().Format(serverTimeLayout))
@@ -2357,10 +2371,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
						Tags:    echoTags,
						Prefix:  dc.prefix(),
						Command: msg.Command,
						Params:  []string{name, text},
						Params:  msgParams,
					})
				}
				handleServicePRIVMSG(ctx, dc, text)
				if !tag {
					handleServicePRIVMSG(ctx, dc, text)
				}
				continue
			}

@@ -2368,84 +2384,53 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
			if err != nil {
				return err
			}
			msgParams[0] = upstreamName

			if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" {
				dc.handleNickServPRIVMSG(ctx, uc, text)
			if tag {
				if _, ok := uc.caps["message-tags"]; !ok {
					continue
				}
			}

			unmarshaledText := text
			if uc.isChannel(upstreamName) {
				unmarshaledText = dc.unmarshalText(uc, text)
			}
			uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
				Tags:    tags,
				Command: msg.Command,
				Params:  []string{upstreamName, unmarshaledText},
			})

			echoTags := tags.Copy()
			echoTags["time"] = irc.TagValue(time.Now().UTC().Format(serverTimeLayout))
			if uc.account != "" {
				echoTags["account"] = irc.TagValue(uc.account)
			}
			echoMsg := &irc.Message{
				Tags:    echoTags,
				Prefix:  &irc.Prefix{Name: uc.nick},
				Command: msg.Command,
				Params:  []string{upstreamName, text},
			if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" {
				dc.handleNickServPRIVMSG(ctx, uc, text)
			}
			uc.produce(upstreamName, echoMsg, dc)

			uc.updateChannelAutoDetach(upstreamName)
		}
	case "TAGMSG":
		var targetsStr string
		if err := parseMessageParams(msg, &targetsStr); err != nil {
			return err
		}
		tags := copyClientTags(msg.Tags)

		for _, name := range strings.Split(targetsStr, ",") {
			if dc.network == nil && casemapASCII(name) == dc.nickCM {
				dc.SendMessage(&irc.Message{
					Tags:    msg.Tags.Copy(),
					Prefix:  dc.prefix(),
					Command: "TAGMSG",
					Params:  []string{name},
			if !tag {
				unmarshaledText := text
				if uc.isChannel(upstreamName) {
					unmarshaledText = dc.unmarshalText(uc, text)
				}
				uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
					Tags:    tags,
					Command: msg.Command,
					Params:  []string{upstreamName, unmarshaledText},
				})
			} else {
				uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
					Tags:    tags,
					Command: msg.Command,
					Params:  []string{upstreamName},
				})
				continue
			}

			if casemapASCII(name) == serviceNickCM {
				continue
			}

			uc, upstreamName, err := dc.unmarshalEntity(name)
			if err != nil {
				return err
			}
			if _, ok := uc.caps["message-tags"]; !ok {
				continue
			}

			uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
				Tags:    tags,
				Command: "TAGMSG",
				Params:  []string{upstreamName},
			})

			echoTags := tags.Copy()
			echoTags["time"] = irc.TagValue(time.Now().UTC().Format(serverTimeLayout))
			if uc.account != "" {
				echoTags["account"] = irc.TagValue(uc.account)
			}
			echoMsg := &irc.Message{
				Tags:    echoTags,
				Prefix:  &irc.Prefix{Name: uc.nick},
				Command: "TAGMSG",
				Params:  []string{upstreamName},
			// If the upstream supports echo message, we'll produce the mesasge
			// when it is echoed from the upstream.
			// Otherwise, produce/log it here because it's the last time we'll see it.
			if !uc.caps["echo-message"] {
				echoTags := tags.Copy()
				echoTags["time"] = irc.TagValue(time.Now().UTC().Format(serverTimeLayout))
				if uc.account != "" {
					echoTags["account"] = irc.TagValue(uc.account)
				}
				echoMsg := &irc.Message{
					Tags:    echoTags,
					Prefix:  &irc.Prefix{Name: uc.nick},
					Command: msg.Command,
					Params:  msgParams,
				}
				uc.produce(upstreamName, echoMsg, dc)
			}
			uc.produce(upstreamName, echoMsg, dc)

			uc.updateChannelAutoDetach(upstreamName)
		}
diff --git a/upstream.go b/upstream.go
index 81509d6..11a4746 100644
--- a/upstream.go
+++ b/upstream.go
@@ -39,6 +39,12 @@ var permanentUpstreamCaps = map[string]bool{
	"draft/extended-monitor":     true,
}

// needAllUpstreamCaps is the list of upstream capabilities that
// require support from all downstreams to be enabled
var needAllUpstreamCaps = map[string]bool{
	"echo-message": true,
}

type registrationError struct {
	*irc.Message
}
@@ -488,8 +494,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
				target = msg.Prefix.Name
			}

			self := uc.isOurNick(msg.Prefix.Name)

			ch := uc.network.channels.Value(target)
			if ch != nil && msg.Command != "TAGMSG" {
			if ch != nil && msg.Command != "TAGMSG" && !self {
				if ch.Detached {
					uc.handleDetachedMessage(ctx, ch, msg)
				}
@@ -523,7 +531,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
				break // wait to receive all capabilities
			}

			uc.requestCaps()
			uc.updateCaps()

			if uc.requestSASL() {
				break // we'll send CAP END after authentication is completed
@@ -540,7 +548,14 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
			caps := strings.Fields(subParams[0])

			for _, name := range caps {
				if err := uc.handleCapAck(ctx, strings.ToLower(name), subCmd == "ACK"); err != nil {
				var enable bool
				if strings.HasPrefix(name, "-") {
					name = strings.TrimPrefix(name, "-")
					enable = false
				} else {
					enable = subCmd == "ACK"
				}
				if err := uc.handleCapAck(ctx, strings.ToLower(name), enable); err != nil {
					return err
				}
			}
@@ -555,7 +570,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
				return newNeedMoreParamsError(msg.Command)
			}
			uc.handleSupportedCaps(subParams[0])
			uc.requestCaps()
			uc.updateCaps()
		case "DEL":
			if len(subParams) < 1 {
				return newNeedMoreParamsError(msg.Command)
@@ -1828,13 +1843,24 @@ func (uc *upstreamConn) handleSupportedCaps(capsStr string) {
	}
}

func (uc *upstreamConn) requestCaps() {
func (uc *upstreamConn) updateCaps() {
	var requestCaps []string
	for c := range permanentUpstreamCaps {
		if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] {
			requestCaps = append(requestCaps, c)
		}
	}
	for c := range needAllUpstreamCaps {
		enabled := true
		uc.forEachDownstream(func(dc *downstreamConn) {
			enabled = enabled && dc.caps[c]
		})
		if !uc.caps[c] && enabled {
			requestCaps = append(requestCaps, c)
		} else if uc.caps[c] && !enabled {
			requestCaps = append(requestCaps, "-"+c)
		}
	}

	if len(requestCaps) == 0 {
		return
@@ -1902,7 +1928,7 @@ func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool)
			Params:  []string{auth.Mechanism},
		})
	default:
		if permanentUpstreamCaps[name] {
		if permanentUpstreamCaps[name] || needAllUpstreamCaps[name] {
			break
		}
		uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
diff --git a/user.go b/user.go
index fefe654..f66c301 100644
--- a/user.go
+++ b/user.go
@@ -646,6 +646,7 @@ func (u *user) run() {

			u.forEachUpstream(func(uc *upstreamConn) {
				uc.updateAway()
				uc.updateCaps()
			})
		case eventDownstreamDisconnected:
			dc := e.dc
@@ -663,6 +664,7 @@ func (u *user) run() {

			u.forEachUpstream(func(uc *upstreamConn) {
				uc.cancelPendingCommandsByDownstreamID(dc.id)
				uc.updateCaps()
				uc.updateAway()
				uc.updateMonitor()
			})

base-commit: 47dfba466c711ee8df8ead370e13fba8175e5d2b
-- 
2.17.1