package parse

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
	"reflect"
	"sort"
	"strings"

	"github.com/tinylib/msgp/gen"
	"github.com/ttacon/chalk"
)

// A FileSet is the in-memory representation of a
// parsed file.
type FileSet struct {
	Package    string              // package name
	Specs      map[string]ast.Expr // type specs in file
	Identities map[string]gen.Elem // processed from specs
	Directives []string            // raw preprocessor directives
	Imports    []*ast.ImportSpec   // imports
}

// File parses a file at the relative path
// provided and produces a new *FileSet.
// If you pass in a path to a directory, the entire
// directory will be parsed; the directory must contain
// exactly one package, and its files are visited in
// sorted filename order so that the collected directives
// and imports come out in a reproducible order.
// If unexported is false, only exported identifiers are included in the FileSet.
// If the resulting FileSet would be empty, an error is returned.
func File(name string, unexported bool) (*FileSet, error) {
	pushstate(name)
	defer popstate()
	fs := &FileSet{
		Specs:      make(map[string]ast.Expr),
		Identities: make(map[string]gen.Elem),
	}

	fset := token.NewFileSet()
	finfo, err := os.Stat(name)
	if err != nil {
		return nil, err
	}
	if finfo.IsDir() {
		pkgs, err := parser.ParseDir(fset, name, nil, parser.ParseComments)
		if err != nil {
			return nil, err
		}
		// distinguish "nothing to parse" from "ambiguous package";
		// both are errors, but they need different messages.
		if len(pkgs) == 0 {
			return nil, fmt.Errorf("no Go packages in directory: %s", name)
		}
		if len(pkgs) > 1 {
			return nil, fmt.Errorf("multiple packages in directory: %s", name)
		}
		var one *ast.Package
		for _, nm := range pkgs {
			one = nm
			break
		}
		fs.Package = one.Name

		// one.Files is a map; walk it in sorted key order so
		// repeated runs collect directives and imports identically.
		fnames := make([]string, 0, len(one.Files))
		for fname := range one.Files {
			fnames = append(fnames, fname)
		}
		sort.Strings(fnames)
		for _, fname := range fnames {
			fl := one.Files[fname]
			pushstate(fl.Name.Name)
			// comments must be harvested before FileExports
			// trims the AST down to exported declarations.
			fs.Directives = append(fs.Directives, yieldComments(fl.Comments)...)
			if !unexported {
				ast.FileExports(fl)
			}
			fs.getTypeSpecs(fl)
			popstate()
		}
	} else {
		f, err := parser.ParseFile(fset, name, nil, parser.ParseComments)
		if err != nil {
			return nil, err
		}
		fs.Package = f.Name.Name
		fs.Directives = yieldComments(f.Comments)
		if !unexported {
			ast.FileExports(f)
		}
		fs.getTypeSpecs(f)
	}

	if len(fs.Specs) == 0 {
		return nil, fmt.Errorf("no definitions in %s", name)
	}

	fs.process()
	fs.applyDirectives()
	fs.propInline()

	return fs, nil
}

// applyDirectives applies all of the directives that
// are known to the parser. additional method-specific
// directives remain in f.Directives
func (f *FileSet) applyDirectives() {
	newdirs := make([]string, 0, len(f.Directives))
	for _, d := range f.Directives {
		// strings.Split with a non-empty separator always
		// yields at least one element, so chunks[0] is safe.
		chunks := strings.Split(d, " ")
		fn, ok := directives[chunks[0]]
		if !ok {
			// not a parser-level directive; keep it for later passes
			newdirs = append(newdirs, d)
			continue
		}
		pushstate(chunks[0])
		if err := fn(chunks, f); err != nil {
			warnln(err.Error())
		}
		popstate()
	}
	f.Directives = newdirs
}

// A linkset is a graph of unresolved
// identities.
//
// Since gen.Ident can only represent
// one level of type indirection (e.g. Foo -> uint8),
// type declarations like `type Foo Bar`
// aren't resolve-able until we've processed
// everything else.
//
// The goal of this dependency resolution
// is to distill the type declaration
// into just one level of indirection.
// In other words, if we have:
//
//	type A uint64
//	type B A
//	type C B
//	type D C
//
// ... then we want to end up
// figuring out that D is just a uint64.
type linkset map[string]*gen.BaseElem func (f *FileSet) resolve(ls linkset) { progress := true for progress && len(ls) > 0 { progress = false for name, elem := range ls { real, ok := f.Identities[elem.TypeName()] if ok { // copy the old type descriptor, // alias it to the new value, // and insert it into the resolved // identities list progress = true nt := real.Copy() nt.Alias(name) f.Identities[name] = nt delete(ls, name) } } } // what's left can't be resolved for name, elem := range ls { warnf("couldn't resolve type %s (%s)\n", name, elem.TypeName()) } } // process takes the contents of f.Specs and // uses them to populate f.Identities func (f *FileSet) process() { deferred := make(linkset) parse: for name, def := range f.Specs { pushstate(name) el := f.parseExpr(def) if el == nil { warnln("failed to parse") popstate() continue parse } // push unresolved identities into // the graph of links and resolve after // we've handled every possible named type. if be, ok := el.(*gen.BaseElem); ok && be.Value == gen.IDENT { deferred[name] = be popstate() continue parse } el.Alias(name) f.Identities[name] = el popstate() } if len(deferred) > 0 { f.resolve(deferred) } } func strToMethod(s string) gen.Method { switch s { case "encode": return gen.Encode case "decode": return gen.Decode case "test": return gen.Test case "size": return gen.Size case "marshal": return gen.Marshal case "unmarshal": return gen.Unmarshal default: return 0 } } func (f *FileSet) applyDirs(p *gen.Printer) { // apply directives of the form // // //msgp:encode ignore {{TypeName}} // loop: for _, d := range f.Directives { chunks := strings.Split(d, " ") if len(chunks) > 1 { for i := range chunks { chunks[i] = strings.TrimSpace(chunks[i]) } m := strToMethod(chunks[0]) if m == 0 { warnf("unknown pass name: %q\n", chunks[0]) continue loop } if fn, ok := passDirectives[chunks[1]]; ok { pushstate(chunks[1]) err := fn(m, chunks[2:], p) if err != nil { warnf("error applying directive: %s\n", err) } popstate() 
} else { warnf("unrecognized directive %q\n", chunks[1]) } } else { warnf("empty directive: %q\n", d) } } } func (f *FileSet) PrintTo(p *gen.Printer) error { f.applyDirs(p) names := make([]string, 0, len(f.Identities)) for name := range f.Identities { names = append(names, name) } sort.Strings(names) for _, name := range names { el := f.Identities[name] el.SetVarname("z") pushstate(el.TypeName()) err := p.Print(el) popstate() if err != nil { return err } } return nil } // getTypeSpecs extracts all of the *ast.TypeSpecs in the file // into fs.Identities, but does not set the actual element func (fs *FileSet) getTypeSpecs(f *ast.File) { // collect all imports... fs.Imports = append(fs.Imports, f.Imports...) // check all declarations... for i := range f.Decls { // for GenDecls... if g, ok := f.Decls[i].(*ast.GenDecl); ok { // and check the specs... for _, s := range g.Specs { // for ast.TypeSpecs.... if ts, ok := s.(*ast.TypeSpec); ok { switch ts.Type.(type) { // this is the list of parse-able // type specs case *ast.StructType, *ast.ArrayType, *ast.StarExpr, *ast.MapType, *ast.Ident: fs.Specs[ts.Name.Name] = ts.Type } } } } } } func fieldName(f *ast.Field) string { switch len(f.Names) { case 0: return stringify(f.Type) case 1: return f.Names[0].Name default: return f.Names[0].Name + " (and others)" } } func (fs *FileSet) parseFieldList(fl *ast.FieldList) []gen.StructField { if fl == nil || fl.NumFields() == 0 { return nil } out := make([]gen.StructField, 0, fl.NumFields()) for _, field := range fl.List { pushstate(fieldName(field)) fds := fs.getField(field) if len(fds) > 0 { out = append(out, fds...) 
} else { warnln("ignored.") } popstate() } return out } // translate *ast.Field into []gen.StructField func (fs *FileSet) getField(f *ast.Field) []gen.StructField { sf := make([]gen.StructField, 1) var extension bool // parse tag; otherwise field name is field tag if f.Tag != nil { body := reflect.StructTag(strings.Trim(f.Tag.Value, "`")).Get("msg") tags := strings.Split(body, ",") if len(tags) == 2 && tags[1] == "extension" { extension = true } // ignore "-" fields if tags[0] == "-" { return nil } sf[0].FieldTag = tags[0] } ex := fs.parseExpr(f.Type) if ex == nil { return nil } // parse field name switch len(f.Names) { case 0: sf[0].FieldName = embedded(f.Type) case 1: sf[0].FieldName = f.Names[0].Name default: // this is for a multiple in-line declaration, // e.g. type A struct { One, Two int } sf = sf[0:0] for _, nm := range f.Names { sf = append(sf, gen.StructField{ FieldTag: nm.Name, FieldName: nm.Name, FieldElem: ex.Copy(), }) } return sf } sf[0].FieldElem = ex if sf[0].FieldTag == "" { sf[0].FieldTag = sf[0].FieldName } // validate extension if extension { switch ex := ex.(type) { case *gen.Ptr: if b, ok := ex.Value.(*gen.BaseElem); ok { b.Value = gen.Ext } else { warnln("couldn't cast to extension.") return nil } case *gen.BaseElem: ex.Value = gen.Ext default: warnln("couldn't cast to extension.") return nil } } return sf } // extract embedded field name // // so, for a struct like // // type A struct { // io.Writer // } // // we want "Writer" func embedded(f ast.Expr) string { switch f := f.(type) { case *ast.Ident: return f.Name case *ast.StarExpr: return embedded(f.X) case *ast.SelectorExpr: return f.Sel.Name default: // other possibilities are disallowed return "" } } // stringify a field type name func stringify(e ast.Expr) string { switch e := e.(type) { case *ast.Ident: return e.Name case *ast.StarExpr: return "*" + stringify(e.X) case *ast.SelectorExpr: return stringify(e.X) + "." 
+ e.Sel.Name case *ast.ArrayType: if e.Len == nil { return "[]" + stringify(e.Elt) } return fmt.Sprintf("[%s]%s", stringify(e.Len), stringify(e.Elt)) case *ast.InterfaceType: if e.Methods == nil || e.Methods.NumFields() == 0 { return "interface{}" } } return "" } // recursively translate ast.Expr to gen.Elem; nil means type not supported // expected input types: // - *ast.MapType (map[T]J) // - *ast.Ident (name) // - *ast.ArrayType ([(sz)]T) // - *ast.StarExpr (*T) // - *ast.StructType (struct {}) // - *ast.SelectorExpr (a.B) // - *ast.InterfaceType (interface {}) func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { switch e := e.(type) { case *ast.MapType: if k, ok := e.Key.(*ast.Ident); ok && k.Name == "string" { if in := fs.parseExpr(e.Value); in != nil { return &gen.Map{Value: in} } } return nil case *ast.Ident: b := gen.Ident(e.Name) // work to resove this expression // can be done later, once we've resolved // everything else. if b.Value == gen.IDENT { if _, ok := fs.Specs[e.Name]; !ok { warnf("non-local identifier: %s\n", e.Name) } } return b case *ast.ArrayType: // special case for []byte if e.Len == nil { if i, ok := e.Elt.(*ast.Ident); ok && i.Name == "byte" { return &gen.BaseElem{Value: gen.Bytes} } } // return early if we don't know // what the slice element type is els := fs.parseExpr(e.Elt) if els == nil { return nil } // array and not a slice if e.Len != nil { switch s := e.Len.(type) { case *ast.BasicLit: return &gen.Array{ Size: s.Value, Els: els, } case *ast.Ident: return &gen.Array{ Size: s.String(), Els: els, } case *ast.SelectorExpr: return &gen.Array{ Size: stringify(s), Els: els, } default: return nil } } return &gen.Slice{Els: els} case *ast.StarExpr: if v := fs.parseExpr(e.X); v != nil { return &gen.Ptr{Value: v} } return nil case *ast.StructType: if fields := fs.parseFieldList(e.Fields); len(fields) > 0 { return &gen.Struct{Fields: fields} } return nil case *ast.SelectorExpr: return gen.Ident(stringify(e)) case *ast.InterfaceType: // 
support `interface{}` if len(e.Methods.List) == 0 { return &gen.BaseElem{Value: gen.Intf} } return nil default: // other types not supported return nil } } func infof(s string, v ...interface{}) { pushstate(s) fmt.Printf(chalk.Green.Color(strings.Join(logctx, ": ")), v...) popstate() } func infoln(s string) { pushstate(s) fmt.Println(chalk.Green.Color(strings.Join(logctx, ": "))) popstate() } func warnf(s string, v ...interface{}) { pushstate(s) fmt.Printf(chalk.Yellow.Color(strings.Join(logctx, ": ")), v...) popstate() } func warnln(s string) { pushstate(s) fmt.Println(chalk.Yellow.Color(strings.Join(logctx, ": "))) popstate() } func fatalf(s string, v ...interface{}) { pushstate(s) fmt.Printf(chalk.Red.Color(strings.Join(logctx, ": ")), v...) popstate() } var logctx []string // push logging state func pushstate(s string) { logctx = append(logctx, s) } // pop logging state func popstate() { logctx = logctx[:len(logctx)-1] }