~sircmpwn/public-inbox

go-bare: Refactor cmd/gen v1 PROPOSED

Timofey: 1
 Refactor cmd/gen

 3 files changed, 296 insertions(+), 302 deletions(-)
#317151 .build.yml success
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~sircmpwn/public-inbox/patches/14242/mbox | git am -3
Learn more about email & git
View this thread in the archives

[PATCH go-bare] Refactor cmd/gen Export this patch

* Use text/template instead of manually writing to output file
* Format generated code with gofmt
---
 cmd/gen/gen.go    | 228 ---------------------------------------
 cmd/gen/main.go   | 269 +++++++++++++++++++++++++++++++++++++++++-----
 example/schema.go | 101 ++++++++---------
 3 files changed, 296 insertions(+), 302 deletions(-)
 delete mode 100644 cmd/gen/gen.go

diff --git a/cmd/gen/gen.go b/cmd/gen/gen.go
deleted file mode 100644
index bdd0381..0000000
--- a/cmd/gen/gen.go
@@ -1,228 +0,0 @@
package main

import (
	"fmt"
	"io"
	"strings"

	"git.sr.ht/~sircmpwn/go-bare/schema"
)

type Context struct {
	unions       []*schema.UserDefinedType
	unionMembers map[schema.Type]interface{}
}

func genTypes(w io.Writer, types []schema.SchemaType) {
	fmt.Fprintf(w, `
// THIS FILE WAS GENERATED BY A TOOL, DO NOT EDIT

import (
	"errors"
	"git.sr.ht/~sircmpwn/go-bare"
)
`)

	ctx := Context{
		unions:       nil,
		unionMembers: make(map[schema.Type]interface{}),
	}
	for _, ty := range types {
		switch ty := ty.(type) {
		case *schema.UserDefinedType:
			ctx.genUserType(w, ty)
		case *schema.UserDefinedEnum:
			ctx.genUserEnum(w, ty)
		}
	}

	if len(ctx.unions) > 0 {
		fmt.Fprintf(w, "\nfunc init() {\n")
		for _, udt := range ctx.unions {
			fmt.Fprintf(w, "\tbare.RegisterUnion((*%s)(nil)).\n", udt.Name())
			ut, _ := udt.Type().(*schema.UnionType)
			for i, ty := range ut.Types() {
				tag := ty.Tag()
				switch ty := ty.Type().(type) {
				case *schema.NamedUserType:
					fmt.Fprintf(w, "\t\tMember(*new(%s), %d)", ty.Name(), tag)
				default:
					panic(fmt.Errorf("TODO: Implement unions with primitive types"))
				}
				if i < len(ut.Types()) - 1 {
					fmt.Fprintf(w, ".\n")
				}
			}
			fmt.Fprintf(w, "\n")
		}
		fmt.Fprintf(w, "}\n")
	}
}

func (ctx *Context) genUserType(w io.Writer, udt *schema.UserDefinedType) {
	if udt.Type().Kind() == schema.Union {
		ctx.genUserUnion(w, udt)
		return
	}

	fmt.Fprintf(w, "\ntype %s ", udt.Name())
	genType(w, udt.Type(), 0)
	fmt.Fprintf(w, "\n")

	fmt.Fprintf(w, "\nfunc (t *%s) Decode(data []byte) error {", udt.Name())
	fmt.Fprintf(w, "\n\treturn bare.Unmarshal(data, t)")
	fmt.Fprintf(w, "\n}\n")

	fmt.Fprintf(w, "\nfunc (t *%s) Encode() ([]byte, error) {", udt.Name())
	fmt.Fprintf(w, "\n\treturn bare.Marshal(t)")
	fmt.Fprintf(w, "\n}\n")
}

func (ctx *Context) genUserEnum(w io.Writer, ude *schema.UserDefinedEnum) {
	// TODO: Disambiguate between enums with conflicting value names
	fmt.Fprintf(w, "\ntype %s %s\n", ude.Name(), primitiveType(ude.Kind()))
	fmt.Fprintf(w, "\nconst (")
	for i, val := range ude.Values() {
		if i == 0 {
			fmt.Fprintf(w, "\n\t%s %s = %d", val.Name(), ude.Name(), val.Value())
		} else {
			fmt.Fprintf(w, "\n\t%s = %d", val.Name(), val.Value())
		}
	}
	fmt.Fprintf(w, "\n)\n")

	fmt.Fprintf(w, "\nfunc (t %s) String() string {", ude.Name())
	fmt.Fprintf(w, "\n\tswitch (t) {")
	for _, val := range ude.Values() {
		fmt.Fprintf(w, "\n\tcase %s:", val.Name())
		fmt.Fprintf(w, "\n\t\treturn \"%s\"", val.Name())
	}
	fmt.Fprintf(w, "\n\t}")
	fmt.Fprintf(w, "\n\tpanic(errors.New(\"Invalid %s value\"))", ude.Name())
	fmt.Fprintf(w, "\n}\n")
}

func (ctx *Context) genUserUnion(w io.Writer, udt *schema.UserDefinedType) {
	fmt.Fprintf(w, "\ntype %s interface {", udt.Name())
	fmt.Fprintf(w, "\n\tbare.Union")
	fmt.Fprintf(w, "\n}\n")

	ut, _ := udt.Type().(*schema.UnionType)
	for _, ty := range ut.Types() {
		// XXX: This doesn't actually work the way it looks like it ought to
		if _, ok := ctx.unionMembers[ty.Type()]; ok {
			continue
		}

		ctx.unionMembers[ty.Type()] = nil

		switch ty := ty.Type().(type) {
		case *schema.NamedUserType:
			fmt.Fprintf(w, "\nfunc (_ %s) IsUnion() { }\n", ty.Name())
		default:
			panic(fmt.Errorf("TODO: Implement unions with primitive types"))
		}
	}

	ctx.unions = append(ctx.unions, udt)
}

func genType(w io.Writer, ty schema.Type, indent int) {
	switch ty := ty.(type) {
	case *schema.PrimitiveType:
		fmt.Fprintf(w, "%s", primitiveType(ty.Kind()))
	case *schema.DataType:
		if ty.Kind() == schema.DataArray {
			fmt.Fprintf(w, "[%d]byte", ty.Length())
		} else {
			fmt.Fprintf(w, "[]byte")
		}
	case *schema.StructType:
		maxName := 0
		for _, field := range ty.Fields() {
			if len(field.Name()) > maxName {
				maxName = len(field.Name())
			}
		}

		fmt.Fprintf(w, "struct {\n")
		for _, field := range ty.Fields() {
			genIndent(w, indent + 1)
			n := fieldName(field.Name())
			fmt.Fprintf(w, "%s ", n)
			for i := len(n); i < maxName; i++ {
				fmt.Fprintf(w, " ")
			}
			genType(w, field.Type(), indent + 1)
			fmt.Fprintf(w, " `bare:\"%s\"`", field.Name())
			fmt.Fprintf(w, "\n")
		}
		genIndent(w, indent)
		fmt.Fprintf(w, "}")
	case *schema.NamedUserType:
		fmt.Fprintf(w, "%s", ty.Name())
	case *schema.MapType:
		fmt.Fprintf(w, "map[")
		genType(w, ty.Key(), indent)
		fmt.Fprintf(w, "]")
		genType(w, ty.Value(), indent)
	case *schema.ArrayType:
		if ty.Kind() == schema.Array {
			fmt.Fprintf(w, "[%d]", ty.Length())
		} else {
			fmt.Fprintf(w, "[]")
		}
		genType(w, ty.Member(), indent)
	case *schema.OptionalType:
		fmt.Fprintf(w, "*")
		genType(w, ty.Subtype(), indent)
	default:
		panic(fmt.Errorf("Unimplemented schema type: %T", ty))
	}
}

func genUnion(w io.Writer, ut *schema.UnionType, indent int) {
}

func primitiveType(kind schema.TypeKind) string {
	switch kind {
	case schema.U8:
		return "uint8"
	case schema.U16:
		return "uint16"
	case schema.U32:
		return "uint32"
	case schema.U64:
		return "uint64"
	case schema.I8:
		return "int8"
	case schema.I16:
		return "int16"
	case schema.I32:
		return "int32"
	case schema.I64:
		return "int64"
	case schema.F32:
		return "float32"
	case schema.F64:
		return "float64"
	case schema.Bool:
		return "bool"
	case schema.String:
		return "string"
	case schema.Void:
		return "struct{}"
	}
	panic(fmt.Errorf("Invalid primitive type %d", kind))
}

func genIndent(w io.Writer, indent int) {
	for ; indent > 0; indent-- {
		fmt.Fprintf(w, "\t")
	}
}

func fieldName(n string) string {
	// TODO: Correct initialisms
	return strings.ToUpper(n[:1]) + n[1:]
}
diff --git a/cmd/gen/main.go b/cmd/gen/main.go
index c8e88dd..9d2e106 100644
--- a/cmd/gen/main.go
+++ b/cmd/gen/main.go
@@ -1,29 +1,235 @@
package main

import (
	"bytes"
	"fmt"
	"go/format"
	"io/ioutil"
	"log"
	"os"
	"strings"
	"text/template"

	"git.sr.ht/~sircmpwn/getopt"

	"git.sr.ht/~sircmpwn/go-bare/schema"
)

const templateString = `
package {{.package}}

// Code generated by go-bare/cmd/gen, DO NOT EDIT.

import (
	"errors"
	"git.sr.ht/~sircmpwn/go-bare"
)

{{ define "type" }}
	{{- if eq (typeKind .) "PrimitiveType"  -}}
		{{ primitiveType .Kind }}
	{{- else if eq (typeKind .) "DataType" -}}
		[{{if gt .Length 0 }}{{.Length}}{{end}}]byte
	{{- else if eq (typeKind .) "ArrayType" -}}
		[{{if gt .Length 0 }}{{.Length}}{{end}}]{{template "type" .Member}}
	{{- else if eq (typeKind .) "StructType" -}}
		struct {
			{{- range .Fields }}
				{{ capitalize .Name }} {{ template "type" .Type }} {{ structTag .Name }}
			{{- end -}}
		}
	{{- else if eq (typeKind .) "NamedUserType" -}}
		{{.Name}}
	{{- else if eq (typeKind .) "MapType" -}}
		map[{{template "type" .Key}}]{{template "type" .Value}}
	{{- else if eq (typeKind .) "OptionalType" -}}
		*{{template "type" .Subtype}}
	{{- end -}}
{{ end }}

{{with .schema}}

{{range .UserTypes}}
	type {{ .Name }} {{ template "type" .Type }}

	func (t *{{ .Name }}) Decode(data []byte) error {
		return bare.Unmarshal(data, t)
	}

	func (t *{{ .Name }}) Encode() ([]byte, error) {
		return bare.Marshal(t)
	}
{{end}}

{{range .Enums}}
type {{ .Name }} {{ primitiveType .Kind }}

{{ $name := .Name }}

const (
		{{- range $i, $el := .Values }}
			{{ .Name }} {{ $name }} = {{ .Value }}
		{{- end -}}
	)

	func (t {{ .Name }}) String() string {
		switch (t) {
		{{- range .Values }}
		case {{ .Name }}:
			return "{{ .Name }}"
		{{- end -}}
		}
		panic(errors.New("Invalid {{.Name}} value"))
	}
{{end}}

{{ if gt (len .Unions) 0 }}
	{{range .Unions}}
		type {{ .Name }} interface {
			bare.Union
		}

		{{range .Type.Types}}
			func (_ {{.Type.Name}}) IsUnion() {}
		{{end}}
	{{end}}

	func init() {
		{{- range .Unions}}
		bare.RegisterUnion((*{{.Name}})(nil)).
			{{ $len := len .Type.Types }}
			{{range $i, $el := .Type.Types}}
				Member(*new({{ template "type" $el.Type}}), {{$el.Tag}}){{- if not (last $len $i) -}}.{{end}}
			{{end}}
		{{ end }}
	}
{{ end}}

{{end}}
`

var funcs = template.FuncMap{
	"typeKind": func(ty interface{}) string {
		switch ty := ty.(type) {
		case *schema.PrimitiveType:
			return "PrimitiveType"
		case *schema.DataType:
			return "DataType"
		case *schema.StructType:
			return "StructType"
		case *schema.NamedUserType:
			return "NamedUserType"
		case *schema.MapType:
			return "MapType"
		case *schema.ArrayType:
			return "ArrayType"
		case *schema.OptionalType:
			return "OptionalType"
		default:
			panic(fmt.Sprintf("Unimplemented schema type: %T", ty))
		}
	},
	"primitiveType": func(t schema.TypeKind) string {
		switch t {
		case schema.U8:
			return "uint8"
		case schema.U16:
			return "uint16"
		case schema.U32:
			return "uint32"
		case schema.U64:
			return "uint64"
		case schema.I8:
			return "int8"
		case schema.I16:
			return "int16"
		case schema.I32:
			return "int32"
		case schema.I64:
			return "int64"
		case schema.F32:
			return "float32"
		case schema.F64:
			return "float64"
		case schema.Bool:
			return "bool"
		case schema.String:
			return "string"
		case schema.Void:
			return "struct{}"
		}
		panic(fmt.Errorf("Invalid primitive type %d", t))
	},
	"structTag": func(name string) string {
		return fmt.Sprintf("`bare:\"%s\"`", name)
	},
	"capitalize": func(s string) string {
		return strings.ToUpper(s[:1]) + s[1:]
	},
	"last": func(len, i int) bool {
		return i+1 == len
	},
}

func main() {
	cfg := parseArgs()
	out := &bytes.Buffer{}

	tmpl, err := template.New("").Funcs(funcs).Parse(templateString)
	if err != nil {
		log.Fatalf("error parsing template: %v", err)
	}

	types := parseSchema(cfg.In, cfg.Skip)

	data := make(map[string]interface{})

	data["package"] = cfg.PackageName
	data["schema"] = types

	err = tmpl.Execute(out, data)
	if err != nil {
		log.Fatalf("error executing template: %v", err)
	}

	// Format generated code
	formatted, err := format.Source(out.Bytes())
	if err != nil {
		log.Println(out.String())
		log.Fatalf("--- error formatting source code: %v", err)
	}

	err = ioutil.WriteFile(cfg.Out, formatted, 0644)
	if err != nil {
		log.Fatalf("error writing output to %s: %e", cfg.Out, err)
	}
}

type Config struct {
	PackageName string
	In          string
	Out         string
	Skip        map[string]bool
}

func parseArgs() *Config {
	cfg := &Config{}

	log.SetFlags(0)
	opts, optind, err := getopt.Getopts(os.Args, "hs:p:")
	if err != nil {
		log.Fatalf("error: %e", err)
	}
	pkg := "gen"
	skip := make(map[string]interface{})

	cfg.PackageName = "gen"
	cfg.Skip = make(map[string]bool)

	for _, opt := range opts {
		switch opt.Option {
		case 'p':
			pkg = opt.Value
			cfg.PackageName = opt.Value
		case 's':
			skip[opt.Value] = nil
			cfg.Skip[opt.Value] = true
		case 'h':
			log.Println("Usage: gen [-p <package>] [-s <skip type>] <input.bare> <output.go>")
			os.Exit(0)
@@ -34,36 +240,51 @@ func main() {
	if len(args) != 2 {
		log.Fatal("Usage: gen [-p <package>] <input.bare> <output.go>")
	}
	in := args[0]
	out := args[1]

	inf, err := os.Open(in)
	cfg.In = args[0]
	cfg.Out = args[1]

	return cfg
}

type Types struct {
	UserTypes []*schema.UserDefinedType
	Enums     []*schema.UserDefinedEnum
	Unions    []*schema.UserDefinedType
}

func parseSchema(path string, skip map[string]bool) Types {
	inf, err := os.Open(path)
	if err != nil {
		log.Fatalf("error opening %s: %e", in, err)
		log.Fatalf("error opening %s: %e", path, err)
	}
	defer inf.Close()

	types, err := schema.Parse(inf)
	schemaTypes, err := schema.Parse(inf)
	if err != nil {
		log.Fatalf("error parsing %s: %e", in, err)
		log.Fatalf("error parsing %s: %e", path, err)
	}

	outf, err := os.Create(out)
	if err != nil {
		log.Fatalf("error opening %s for writing: %e", out, err)
	}
	defer outf.Close()
	fmt.Fprintf(outf, "package %s\n", pkg)

	if len(skip) != 0 {
		var typesp []schema.SchemaType
		for _, ty := range types {
			if _, ok := skip[ty.Name()]; !ok {
				typesp = append(typesp, ty)
	types := Types{}

	for _, ty := range schemaTypes {
		if skip[ty.Name()] {
			continue
		}

		switch ty := ty.(type) {
		case *schema.UserDefinedType:
			if ty.Type().Kind() == schema.Union {
				types.Unions = append(types.Unions, ty)
				continue
			}
			types.UserTypes = append(types.UserTypes, ty)

		case *schema.UserDefinedEnum:
			types.Enums = append(types.Enums, ty)

		}
		types = typesp
	}

	genTypes(outf, types)
	return types
}
diff --git a/example/schema.go b/example/schema.go
index bfc1987..b73cab7 100644
--- a/example/schema.go
+++ b/example/schema.go
@@ -1,6 +1,6 @@
package example

// THIS FILE WAS GENERATED BY A TOOL, DO NOT EDIT
// Code generated by go-bare/cmd/gen, DO NOT EDIT.

import (
	"errors"
@@ -17,37 +17,11 @@ func (t *PublicKey) Encode() ([]byte, error) {
	return bare.Marshal(t)
}

type Department uint8

const (
	ACCOUNTING Department = 0
	ADMINISTRATION = 1
	CUSTOMER_SERVICE = 2
	DEVELOPMENT = 3
	JSMITH = 99
)

func (t Department) String() string {
	switch (t) {
	case ACCOUNTING:
		return "ACCOUNTING"
	case ADMINISTRATION:
		return "ADMINISTRATION"
	case CUSTOMER_SERVICE:
		return "CUSTOMER_SERVICE"
	case DEVELOPMENT:
		return "DEVELOPMENT"
	case JSMITH:
		return "JSMITH"
	}
	panic(errors.New("Invalid Department value"))
}

type Customer struct {
	Name     string `bare:"name"`
	Email    string `bare:"email"`
	Address  Address `bare:"address"`
	Orders   []struct {
	Name    string  `bare:"name"`
	Email   string  `bare:"email"`
	Address Address `bare:"address"`
	Orders  []struct {
		OrderId  int64 `bare:"orderId"`
		Quantity int32 `bare:"quantity"`
	} `bare:"orders"`
@@ -63,12 +37,12 @@ func (t *Customer) Encode() ([]byte, error) {
}

type Employee struct {
	Name       string `bare:"name"`
	Email      string `bare:"email"`
	Address    Address `bare:"address"`
	Department Department `bare:"department"`
	HireDate   Time `bare:"hireDate"`
	PublicKey  *PublicKey `bare:"publicKey"`
	Name       string            `bare:"name"`
	Email      string            `bare:"email"`
	Address    Address           `bare:"address"`
	Department Department        `bare:"department"`
	HireDate   Time              `bare:"hireDate"`
	PublicKey  *PublicKey        `bare:"publicKey"`
	Metadata   map[string][]byte `bare:"metadata"`
}

@@ -90,21 +64,11 @@ func (t *TerminatedEmployee) Encode() ([]byte, error) {
	return bare.Marshal(t)
}

type Person interface {
	bare.Union
}

func (_ Customer) IsUnion() { }

func (_ Employee) IsUnion() { }

func (_ TerminatedEmployee) IsUnion() { }

type Address struct {
	Address [4]string `bare:"address"`
	City    string `bare:"city"`
	State   string `bare:"state"`
	Country string `bare:"country"`
	City    string    `bare:"city"`
	State   string    `bare:"state"`
	Country string    `bare:"country"`
}

func (t *Address) Decode(data []byte) error {
@@ -115,9 +79,46 @@ func (t *Address) Encode() ([]byte, error) {
	return bare.Marshal(t)
}

type Department uint8

const (
	ACCOUNTING       Department = 0
	ADMINISTRATION   Department = 1
	CUSTOMER_SERVICE Department = 2
	DEVELOPMENT      Department = 3
	JSMITH           Department = 99
)

func (t Department) String() string {
	switch t {
	case ACCOUNTING:
		return "ACCOUNTING"
	case ADMINISTRATION:
		return "ADMINISTRATION"
	case CUSTOMER_SERVICE:
		return "CUSTOMER_SERVICE"
	case DEVELOPMENT:
		return "DEVELOPMENT"
	case JSMITH:
		return "JSMITH"
	}
	panic(errors.New("Invalid Department value"))
}

type Person interface {
	bare.Union
}

func (_ Customer) IsUnion() {}

func (_ Employee) IsUnion() {}

func (_ TerminatedEmployee) IsUnion() {}

func init() {
	bare.RegisterUnion((*Person)(nil)).
		Member(*new(Customer), 0).
		Member(*new(Employee), 1).
		Member(*new(TerminatedEmployee), 2)

}
-- 
2.28.0
builds.sr.ht
go-bare/patches/.build.yml: SUCCESS in 35s

[Refactor cmd/gen][0] from [Timofey][1]

[0]: https://lists.sr.ht/~sircmpwn/public-inbox/patches/14242
[1]: mailto:chikker@cock.li

✓ #317151 SUCCESS go-bare/patches/.build.yml https://builds.sr.ht/~sircmpwn/job/317151