Skip to content

Commit c2d1048

Browse files
committed
Refactor
1 parent ec1e1a4 commit c2d1048

File tree

4 files changed

+230
-238
lines changed

4 files changed

+230
-238
lines changed

analyzer.go

Lines changed: 7 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
package wraperr
22

33
import (
4-
"bufio"
54
"go/ast"
65
"go/token"
7-
"go/types"
8-
"os"
9-
"strings"
106

117
"golang.org/x/tools/go/analysis"
128
"golang.org/x/tools/go/analysis/passes/inspect"
@@ -57,177 +53,22 @@ func init() {
5753
}
5854

5955
func runAnalyze(pass *analysis.Pass) (interface{}, error) {
60-
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
56+
r := newFileReader(pass.Fset)
6157

62-
isWrapped := func(call *ast.CallExpr) bool {
63-
switch fexpr := call.Fun.(type) {
64-
case *ast.SelectorExpr:
65-
if f, ok := pass.TypesInfo.ObjectOf(fexpr.Sel).(*types.Func); ok {
66-
chunks := []string{}
67-
for _, chunk := range strings.Split(f.FullName(), "/") {
68-
if chunk == "vendor" {
69-
chunks = []string{}
70-
} else {
71-
chunks = append(chunks, chunk)
72-
}
73-
}
74-
_, ok = wrapperFuncSet[strings.Join(chunks, "/")]
75-
return ok
76-
}
77-
}
78-
return false
58+
reportFunc := func(assignedAt, returnedAt token.Pos) {
59+
occPos := pass.Fset.Position(assignedAt)
60+
line := sprintInlineCode(r.GetLine(assignedAt))
61+
pass.Reportf(returnedAt, "the error is assigned on L%d: %s", occPos.Line, line)
7962
}
8063

64+
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
8165
nodeFilter := []ast.Node{
8266
(*ast.FuncDecl)(nil),
8367
}
8468

8569
inspect.Preorder(nodeFilter, func(n ast.Node) {
86-
f := n.(*ast.FuncDecl)
87-
if f.Type == nil || f.Type.Results == nil {
88-
return
89-
}
90-
91-
fields := f.Type.Results.List
92-
errInReturn := make([]bool, 0, 2*len(fields))
93-
errNames := make(map[string]struct{}, 2*len(fields))
94-
errIdents := map[string]*errIdent{}
95-
var ok bool
96-
for _, f := range fields {
97-
isErr := isErrorType(pass.TypesInfo.TypeOf(f.Type))
98-
ok = ok || isErr
99-
if len(f.Names) == 0 {
100-
errInReturn = append(errInReturn, isErr)
101-
}
102-
for _, n := range f.Names {
103-
errInReturn = append(errInReturn, isErr)
104-
errNames[n.Name] = struct{}{}
105-
}
106-
}
107-
if !ok {
108-
return
109-
}
110-
111-
lines := make(map[string][]string)
112-
113-
// ref: https://github.com/kisielk/errcheck/blob/1787c4bee836470bf45018cfbc783650db3c6501/internal/errcheck/errcheck.go#L488-L498
114-
getLine := func(tp token.Pos) string {
115-
pos := pass.Fset.Position(tp)
116-
foundLines, ok := lines[pos.Filename]
117-
118-
if !ok {
119-
f, err := os.Open(pos.Filename)
120-
if err == nil {
121-
sc := bufio.NewScanner(f)
122-
for sc.Scan() {
123-
foundLines = append(foundLines, sc.Text())
124-
}
125-
lines[pos.Filename] = foundLines
126-
f.Close()
127-
}
128-
}
129-
130-
line := "??"
131-
if pos.Line-1 < len(foundLines) {
132-
line = strings.TrimSpace(foundLines[pos.Line-1])
133-
}
134-
135-
return line
136-
}
137-
138-
recordUnwrappedError := func(occurredAt, returnedAt token.Pos) {
139-
occPos := pass.Fset.Position(occurredAt)
140-
pass.Reportf(returnedAt, "the error is assigned on L%d: %s", occPos.Line, sprintInlineCode(getLine(occurredAt)))
141-
}
142-
143-
ast.Inspect(f, func(n ast.Node) bool {
144-
switch stmt := n.(type) {
145-
case *ast.AssignStmt:
146-
var errIds []*ast.Ident
147-
for _, expr := range stmt.Lhs {
148-
if !isErrorType(pass.TypesInfo.TypeOf(expr)) {
149-
continue
150-
}
151-
if id, ok := expr.(*ast.Ident); ok {
152-
errIds = append(errIds, id)
153-
}
154-
}
155-
if len(errIds) > 0 {
156-
// Detect wrapped error assignment
157-
var wrapped bool
158-
for _, expr := range stmt.Rhs {
159-
if cexpr, ok := expr.(*ast.CallExpr); ok {
160-
if isWrapped(cexpr) {
161-
wrapped = true
162-
}
163-
}
164-
}
165-
for _, id := range errIds {
166-
errIdents[id.Name] = &errIdent{Ident: id, wrapped: wrapped}
167-
}
168-
}
169-
case *ast.ReturnStmt:
170-
switch len(stmt.Results) {
171-
case 0:
172-
// Named return values
173-
for n := range errNames {
174-
if errIdent, ok := errIdents[n]; ok && !errIdent.wrapped {
175-
recordUnwrappedError(errIdent.Pos(), stmt.Return)
176-
}
177-
}
178-
case len(errInReturn):
179-
// Simple return
180-
for i, expr := range stmt.Results {
181-
if !errInReturn[i] {
182-
continue
183-
}
184-
switch expr := expr.(type) {
185-
case *ast.Ident:
186-
if errIdent, ok := errIdents[expr.Name]; ok && !errIdent.wrapped {
187-
recordUnwrappedError(errIdent.Pos(), expr.NamePos)
188-
}
189-
case *ast.CallExpr:
190-
if !isWrapped(expr) {
191-
recordUnwrappedError(expr.Pos(), expr.Lparen)
192-
}
193-
default:
194-
// TODO: should report unexpected exper
195-
}
196-
}
197-
case 1:
198-
// Return another function directly
199-
switch expr := stmt.Results[0].(type) {
200-
case *ast.CallExpr:
201-
if !isWrapped(expr) {
202-
recordUnwrappedError(expr.Pos(), expr.Pos())
203-
}
204-
default:
205-
// TODO: should report unexpected exper
206-
}
207-
default:
208-
// TODO: should report unexpected exper
209-
}
210-
}
211-
return true
212-
})
70+
NewChecker(pass.Fset, pass.TypesInfo, n.(*ast.FuncDecl)).Check(reportFunc)
21371
})
21472

21573
return nil, nil
21674
}
217-
218-
func sprintInlineCode(s string) string {
219-
cc := 1
220-
c := cc
221-
for _, r := range s {
222-
if r == '`' {
223-
cc++
224-
if cc > c {
225-
c = cc
226-
}
227-
} else {
228-
cc = 1
229-
}
230-
}
231-
q := strings.Repeat("`", c)
232-
return q + s + q
233-
}

checker.go

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package wraperr
2+
3+
import (
4+
"go/ast"
5+
"go/token"
6+
"go/types"
7+
"strings"
8+
)
9+
10+
type Checker interface {
11+
Check(ReportFunc)
12+
}
13+
14+
type ReportFunc func(assignedAt, returnedAt token.Pos)
15+
16+
func NewChecker(
17+
fset *token.FileSet,
18+
info *types.Info,
19+
f *ast.FuncDecl,
20+
) Checker {
21+
return &checkerImpl{
22+
fset: fset,
23+
info: info,
24+
f: f,
25+
}
26+
}
27+
28+
type checkerImpl struct {
29+
fset *token.FileSet
30+
info *types.Info
31+
f *ast.FuncDecl
32+
33+
errInReturn []bool
34+
errNames map[string]struct{}
35+
errIdents map[string]*errIdent
36+
}
37+
38+
func (c *checkerImpl) Check(report ReportFunc) {
39+
if !c.init() {
40+
return
41+
}
42+
43+
ast.Inspect(c.f, func(n ast.Node) bool {
44+
switch stmt := n.(type) {
45+
case *ast.AssignStmt:
46+
c.checkAssignment(stmt)
47+
case *ast.ReturnStmt:
48+
c.checkReturn(stmt, report)
49+
}
50+
return true
51+
})
52+
}
53+
54+
func (c *checkerImpl) init() (ok bool) {
55+
if c.f.Type == nil || c.f.Type.Results == nil {
56+
return
57+
}
58+
59+
fields := c.f.Type.Results.List
60+
61+
c.errInReturn = make([]bool, 0, 2*len(fields))
62+
c.errNames = make(map[string]struct{}, 2*len(fields))
63+
c.errIdents = map[string]*errIdent{}
64+
65+
for _, f := range fields {
66+
isErr := isErrorType(c.info.TypeOf(f.Type))
67+
ok = ok || isErr
68+
if len(f.Names) == 0 {
69+
c.errInReturn = append(c.errInReturn, isErr)
70+
}
71+
for _, n := range f.Names {
72+
c.errInReturn = append(c.errInReturn, isErr)
73+
c.errNames[n.Name] = struct{}{}
74+
}
75+
}
76+
77+
return
78+
}
79+
80+
func (c *checkerImpl) checkAssignment(stmt *ast.AssignStmt) {
81+
var errIds []*ast.Ident
82+
for _, expr := range stmt.Lhs {
83+
if !isErrorType(c.info.TypeOf(expr)) {
84+
continue
85+
}
86+
if id, ok := expr.(*ast.Ident); ok {
87+
errIds = append(errIds, id)
88+
}
89+
}
90+
if len(errIds) > 0 {
91+
// Detect wrapped error assignment
92+
var wrapped bool
93+
for _, expr := range stmt.Rhs {
94+
if cexpr, ok := expr.(*ast.CallExpr); ok {
95+
if c.isWrapped(cexpr) {
96+
wrapped = true
97+
}
98+
}
99+
}
100+
for _, id := range errIds {
101+
c.errIdents[id.Name] = &errIdent{Ident: id, wrapped: wrapped}
102+
}
103+
}
104+
}
105+
106+
func (c *checkerImpl) checkReturn(stmt *ast.ReturnStmt, report ReportFunc) {
107+
switch len(stmt.Results) {
108+
case 0:
109+
// Named return values
110+
for n := range c.errNames {
111+
if errIdent, ok := c.errIdents[n]; ok && !errIdent.wrapped {
112+
report(errIdent.Pos(), stmt.Return)
113+
}
114+
}
115+
case len(c.errInReturn):
116+
// Simple return
117+
for i, expr := range stmt.Results {
118+
if !c.errInReturn[i] {
119+
continue
120+
}
121+
switch expr := expr.(type) {
122+
case *ast.Ident:
123+
if errIdent, ok := c.errIdents[expr.Name]; ok && !errIdent.wrapped {
124+
report(errIdent.Pos(), expr.NamePos)
125+
}
126+
case *ast.CallExpr:
127+
if !c.isWrapped(expr) {
128+
report(expr.Pos(), expr.Lparen)
129+
}
130+
default:
131+
// TODO: should report unexpected exper
132+
}
133+
}
134+
case 1:
135+
// Return another function directly
136+
switch expr := stmt.Results[0].(type) {
137+
case *ast.CallExpr:
138+
if !c.isWrapped(expr) {
139+
report(expr.Pos(), expr.Pos())
140+
}
141+
default:
142+
// TODO: should report unexpected exper
143+
}
144+
default:
145+
// TODO: should report unexpected exper
146+
}
147+
}
148+
149+
func (c *checkerImpl) isWrapped(call *ast.CallExpr) bool {
150+
switch fexpr := call.Fun.(type) {
151+
case *ast.SelectorExpr:
152+
if f, ok := c.info.ObjectOf(fexpr.Sel).(*types.Func); ok {
153+
chunks := []string{}
154+
for _, chunk := range strings.Split(f.FullName(), "/") {
155+
if chunk == "vendor" {
156+
chunks = []string{}
157+
} else {
158+
chunks = append(chunks, chunk)
159+
}
160+
}
161+
_, ok = wrapperFuncSet[strings.Join(chunks, "/")]
162+
return ok
163+
}
164+
}
165+
return false
166+
}

0 commit comments

Comments
 (0)