~samwhited/mellium-devel

xmpp: xmpp: make STARTTLS always required v1 APPLIED

Sam Whited: 1
 xmpp: make STARTTLS always required

 8 files changed, 76 insertions(+), 83 deletions(-)
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~samwhited/mellium-devel/patches/10529/mbox | git am -3
Learn more about email & git
View this thread in the archives

[PATCH xmpp] xmpp: make STARTTLS always required Export this patch

TLS (or at the time, SSL) may have been an optional feature in the past,
but it's not anymore. These days it's far more likely that a server will
always want to require TLS in some form, so giving the user the ability
to turn it off just means we're giving users who won't understand the
consequences of their actions a knob to twiddle. In the very rare case
that a user actually *does* need STARTTLS to be an optional stream
feature, I don't think it's something we should support. For this rare
use case, they'll have to take the maintenance burden on themselves by
copy/pasting the StartTLS feature code and tweaking it for their needs.

Fixes #50

Signed-off-by: Sam Whited <sam@samwhited.com>
---
 CHANGELOG.md             |   5 ++
 echobot_example_test.go  |   2 +-
 examples/echobot/echo.go |   2 +-
 examples/im/main.go      |   2 +-
 examples/msgrepl/main.go |   2 +-
 session_test.go          |   2 +-
 starttls.go              |  20 +++----
 starttls_test.go         | 124 ++++++++++++++++++---------------------
 8 files changed, 76 insertions(+), 83 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bf715a223b61..c46a464e21c1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.

## Unreleased

### Breaking

- xmpp: remove option to make STARTTLS feature optional


### Added

- xmpp: `ConnectionState` method
diff --git a/echobot_example_test.go b/echobot_example_test.go
index 458f7421934a..705f4a84e4eb 100644
--- a/echobot_example_test.go
+++ b/echobot_example_test.go
@@ -35,7 +35,7 @@ func Example_echobot() {
	s, err := xmpp.DialClientSession(
		context.TODO(), j,
		xmpp.BindResource(),
		xmpp.StartTLS(true, &tls.Config{
		xmpp.StartTLS(&tls.Config{
			ServerName: j.Domain().String(),
		}),
		xmpp.SASL("", pass, sasl.ScramSha1Plus, sasl.ScramSha1, sasl.Plain),
diff --git a/examples/echobot/echo.go b/examples/echobot/echo.go
index 6711b625142d..1232ebcc46af 100644
--- a/examples/echobot/echo.go
+++ b/examples/echobot/echo.go
@@ -42,7 +42,7 @@ func echo(ctx context.Context, addr, pass string, xmlIn, xmlOut io.Writer, logge
		Lang: "en",
		Features: []xmpp.StreamFeature{
			xmpp.BindResource(),
			xmpp.StartTLS(true, &tls.Config{
			xmpp.StartTLS(&tls.Config{
				ServerName: j.Domain().String(),
			}),
			xmpp.SASL("", pass, sasl.ScramSha1Plus, sasl.ScramSha1, sasl.Plain),
diff --git a/examples/im/main.go b/examples/im/main.go
index 47fdc0940810..0be981831082 100644
--- a/examples/im/main.go
+++ b/examples/im/main.go
@@ -129,7 +129,7 @@ func main() {
	session, err := xmpp.DialClientSession(
		dialCtx, parsedAddr,
		xmpp.BindResource(),
		xmpp.StartTLS(true, &tls.Config{
		xmpp.StartTLS(&tls.Config{
			ServerName: parsedAddr.Domain().String(),
		}),
		xmpp.SASL("", pass, sasl.ScramSha256Plus, sasl.ScramSha1Plus, sasl.ScramSha256, sasl.ScramSha1, sasl.Plain),
diff --git a/examples/msgrepl/main.go b/examples/msgrepl/main.go
index b5ff8ea06c8f..e75ce020feba 100644
--- a/examples/msgrepl/main.go
+++ b/examples/msgrepl/main.go
@@ -68,7 +68,7 @@ func main() {
	session, err := xmpp.DialClientSession(
		dialCtx, parsedAddr,
		xmpp.BindResource(),
		xmpp.StartTLS(true, &tls.Config{
		xmpp.StartTLS(&tls.Config{
			ServerName: parsedAddr.Domain().String(),
		}),
		xmpp.SASL("", pass, sasl.ScramSha1Plus, sasl.ScramSha1, sasl.Plain),
diff --git a/session_test.go b/session_test.go
index f7346dcaf1c8..8118ae4148e5 100644
--- a/session_test.go
+++ b/session_test.go
@@ -99,7 +99,7 @@ var negotiateTests = [...]negotiateTestCase{
	0: {negotiator: errNegotiator, err: errTestNegotiate},
	1: {
		negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{
			Features: []xmpp.StreamFeature{xmpp.StartTLS(true, nil)},
			Features: []xmpp.StreamFeature{xmpp.StartTLS(nil)},
		}),
		in:  `<stream:stream id='316732270768047465' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client'><stream:features><other/></stream:features>`,
		out: `<?xml version="1.0" encoding="UTF-8"?><stream:stream to='' from='' version='1.0' xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams'><starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`,
diff --git a/starttls.go b/starttls.go
index 189521cfd819..e6f58ed41dd7 100644
--- a/starttls.go
+++ b/starttls.go
@@ -19,24 +19,22 @@ import (
// StartTLS returns a new stream feature that can be used for negotiating TLS.
// If cfg is nil, a default configuration is used that uses the domainpart of
// the sessions local address as the ServerName.
func StartTLS(required bool, cfg *tls.Config) StreamFeature {
func StartTLS(cfg *tls.Config) StreamFeature {
	return StreamFeature{
		Name:       xml.Name{Local: "starttls", Space: ns.StartTLS},
		Prohibited: Secure,
		List: func(ctx context.Context, e xmlstream.TokenWriter, start xml.StartElement) (req bool, err error) {
			if err = e.EncodeToken(start); err != nil {
				return required, err
				return true, err
			}
			if required {
				startRequired := xml.StartElement{Name: xml.Name{Space: "", Local: "required"}}
				if err = e.EncodeToken(startRequired); err != nil {
					return required, err
				}
				if err = e.EncodeToken(startRequired.End()); err != nil {
					return required, err
				}
			startRequired := xml.StartElement{Name: xml.Name{Space: "", Local: "required"}}
			if err = e.EncodeToken(startRequired); err != nil {
				return true, err
			}
			return required, e.EncodeToken(start.End())
			if err = e.EncodeToken(startRequired.End()); err != nil {
				return true, err
			}
			return true, e.EncodeToken(start.End())
		},
		Parse: func(ctx context.Context, r xml.TokenReader, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {
diff --git a/starttls_test.go b/starttls_test.go
index fcfc8d5d575b..1d3f0fe2ab6f 100644
--- a/starttls_test.go
+++ b/starttls_test.go
@@ -22,76 +22,66 @@ import (
// There is no room for variation on the starttls feature negotiation, so step
// through the list process token for token.
func TestStartTLSList(t *testing.T) {
	for _, req := range []bool{true, false} {
		name := "optional"
		if req {
			name = "required"
		}
		t.Run(name, func(t *testing.T) {
			stls := xmpp.StartTLS(req, nil)
			var b bytes.Buffer
			e := xml.NewEncoder(&b)
			start := xml.StartElement{Name: xml.Name{Space: ns.StartTLS, Local: "starttls"}}
			r, err := stls.List(context.Background(), e, start)
			switch {
			case err != nil:
				t.Fatal(err)
			case r != req:
				t.Errorf("Expected StartTLS listing required to be %v but got %v", req, r)
			}
			if err = e.Flush(); err != nil {
				t.Fatal(err)
			}
	stls := xmpp.StartTLS(nil)
	var b bytes.Buffer
	e := xml.NewEncoder(&b)
	start := xml.StartElement{Name: xml.Name{Space: ns.StartTLS, Local: "starttls"}}
	r, err := stls.List(context.Background(), e, start)
	switch {
	case err != nil:
		t.Fatal(err)
	case !r:
		t.Error("Expected StartTLS listing to be required")
	}
	if err = e.Flush(); err != nil {
		t.Fatal(err)
	}

			d := xml.NewDecoder(&b)
			tok, err := d.Token()
			if err != nil {
				t.Fatal(err)
			}
			se := tok.(xml.StartElement)
			switch {
			case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
				t.Errorf("Expected starttls to start with %+v token but got %+v", ns.StartTLS, se.Name)
			case len(se.Attr) != 1:
				t.Errorf("Expected starttls start element to have 1 attribute (xmlns), but got %+v", se.Attr)
			}
			if req {
				tok, err = d.Token()
				if err != nil {
					t.Fatal(err)
				}
				se := tok.(xml.StartElement)
				switch {
				case se.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
					t.Errorf("Expected required start element but got %+v", se)
				case len(se.Attr) > 0:
					t.Errorf("Expected starttls required to have no attributes but got %d", len(se.Attr))
				}
				tok, err = d.Token()
				if err != nil {
					t.Fatal(err)
				}
				ee := tok.(xml.EndElement)
				switch {
				case se.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
					t.Errorf("Expected required end element but got %+v", ee)
				}
			}
			tok, err = d.Token()
			if err != nil {
				t.Fatal(err)
			}
			ee := tok.(xml.EndElement)
			switch {
			case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
				t.Errorf("Expected starttls end element but got %+v", ee)
			}
		})
	d := xml.NewDecoder(&b)
	tok, err := d.Token()
	if err != nil {
		t.Fatal(err)
	}
	se := tok.(xml.StartElement)
	switch {
	case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
		t.Errorf("Expected starttls to start with %+v token but got %+v", ns.StartTLS, se.Name)
	case len(se.Attr) != 1:
		t.Errorf("Expected starttls start element to have 1 attribute (xmlns), but got %+v", se.Attr)
	}
	tok, err = d.Token()
	if err != nil {
		t.Fatal(err)
	}
	reqStart := tok.(xml.StartElement)
	switch {
	case reqStart.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
		t.Errorf("Expected required start element but got %+v", se)
	case len(reqStart.Attr) > 0:
		t.Errorf("Expected starttls required to have no attributes but got %d", len(reqStart.Attr))
	}
	tok, err = d.Token()
	if err != nil {
		t.Fatal(err)
	}
	ee := tok.(xml.EndElement)
	switch {
	case reqStart.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
		t.Errorf("Expected required end element but got %+v", ee)
	}
	tok, err = d.Token()
	if err != nil {
		t.Fatal(err)
	}
	ee = tok.(xml.EndElement)
	switch {
	case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
		t.Errorf("Expected starttls end element but got %+v", ee)
	}
}

func TestStartTLSParse(t *testing.T) {
	stls := xmpp.StartTLS(true, nil)
	stls := xmpp.StartTLS(nil)
	for i, test := range [...]struct {
		msg string
		req bool
@@ -131,7 +121,7 @@ func (nopRWC) Close() error {
}

func TestNegotiateServer(t *testing.T) {
	stls := xmpp.StartTLS(true, &tls.Config{})
	stls := xmpp.StartTLS(&tls.Config{})
	var b bytes.Buffer
	c := xmpptest.NewSession(xmpp.Received, nopRWC{&b, &b})
	_, rw, err := stls.Negotiate(context.Background(), c, nil)
@@ -169,7 +159,7 @@ func TestNegotiateClient(t *testing.T) {
		7: {[]string{`chardata not start element`}, true, false, 0},
	} {
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			stls := xmpp.StartTLS(true, &tls.Config{})
			stls := xmpp.StartTLS(&tls.Config{})
			r := strings.NewReader(strings.Join(test.responses, "\n"))
			var b bytes.Buffer
			c := xmpptest.NewSession(0, nopRWC{r, &b})
-- 
2.26.2