|
1 | 1 | package wraperr |
2 | 2 |
|
3 | 3 | import ( |
4 | | -"bufio" |
5 | 4 | "go/ast" |
6 | 5 | "go/token" |
7 | | -"go/types" |
8 | | -"os" |
9 | | -"strings" |
10 | 6 |
|
11 | 7 | "golang.org/x/tools/go/analysis" |
12 | 8 | "golang.org/x/tools/go/analysis/passes/inspect" |
@@ -57,177 +53,22 @@ func init() { |
57 | 53 | } |
58 | 54 |
|
59 | 55 | func runAnalyze(pass *analysis.Pass) (interface{}, error) { |
60 | | -inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) |
| 56 | +r := newFileReader(pass.Fset) |
61 | 57 |
|
62 | | -isWrapped := func(call *ast.CallExpr) bool { |
63 | | -switch fexpr := call.Fun.(type) { |
64 | | -case *ast.SelectorExpr: |
65 | | -if f, ok := pass.TypesInfo.ObjectOf(fexpr.Sel).(*types.Func); ok { |
66 | | -chunks := []string{} |
67 | | -for _, chunk := range strings.Split(f.FullName(), "/") { |
68 | | -if chunk == "vendor" { |
69 | | -chunks = []string{} |
70 | | -} else { |
71 | | -chunks = append(chunks, chunk) |
72 | | -} |
73 | | -} |
74 | | -_, ok = wrapperFuncSet[strings.Join(chunks, "/")] |
75 | | -return ok |
76 | | -} |
77 | | -} |
78 | | -return false |
| 58 | +reportFunc := func(assignedAt, returnedAt token.Pos) { |
| 59 | +occPos := pass.Fset.Position(assignedAt) |
| 60 | +line := sprintInlineCode(r.GetLine(assignedAt)) |
| 61 | +pass.Reportf(returnedAt, "the error is assigned on L%d: %s", occPos.Line, line) |
79 | 62 | } |
80 | 63 |
|
| 64 | +inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) |
81 | 65 | nodeFilter := []ast.Node{ |
82 | 66 | (*ast.FuncDecl)(nil), |
83 | 67 | } |
84 | 68 |
|
85 | 69 | inspect.Preorder(nodeFilter, func(n ast.Node) { |
86 | | -f := n.(*ast.FuncDecl) |
87 | | -if f.Type == nil || f.Type.Results == nil { |
88 | | -return |
89 | | -} |
90 | | - |
91 | | -fields := f.Type.Results.List |
92 | | -errInReturn := make([]bool, 0, 2*len(fields)) |
93 | | -errNames := make(map[string]struct{}, 2*len(fields)) |
94 | | -errIdents := map[string]*errIdent{} |
95 | | -var ok bool |
96 | | -for _, f := range fields { |
97 | | -isErr := isErrorType(pass.TypesInfo.TypeOf(f.Type)) |
98 | | -ok = ok || isErr |
99 | | -if len(f.Names) == 0 { |
100 | | -errInReturn = append(errInReturn, isErr) |
101 | | -} |
102 | | -for _, n := range f.Names { |
103 | | -errInReturn = append(errInReturn, isErr) |
104 | | -errNames[n.Name] = struct{}{} |
105 | | -} |
106 | | -} |
107 | | -if !ok { |
108 | | -return |
109 | | -} |
110 | | - |
111 | | -lines := make(map[string][]string) |
112 | | - |
113 | | -// ref: https://github.com/kisielk/errcheck/blob/1787c4bee836470bf45018cfbc783650db3c6501/internal/errcheck/errcheck.go#L488-L498 |
114 | | -getLine := func(tp token.Pos) string { |
115 | | -pos := pass.Fset.Position(tp) |
116 | | -foundLines, ok := lines[pos.Filename] |
117 | | - |
118 | | -if !ok { |
119 | | -f, err := os.Open(pos.Filename) |
120 | | -if err == nil { |
121 | | -sc := bufio.NewScanner(f) |
122 | | -for sc.Scan() { |
123 | | -foundLines = append(foundLines, sc.Text()) |
124 | | -} |
125 | | -lines[pos.Filename] = foundLines |
126 | | -f.Close() |
127 | | -} |
128 | | -} |
129 | | - |
130 | | -line := "??" |
131 | | -if pos.Line-1 < len(foundLines) { |
132 | | -line = strings.TrimSpace(foundLines[pos.Line-1]) |
133 | | -} |
134 | | - |
135 | | -return line |
136 | | -} |
137 | | - |
138 | | -recordUnwrappedError := func(occurredAt, returnedAt token.Pos) { |
139 | | -occPos := pass.Fset.Position(occurredAt) |
140 | | -pass.Reportf(returnedAt, "the error is assigned on L%d: %s", occPos.Line, sprintInlineCode(getLine(occurredAt))) |
141 | | -} |
142 | | - |
143 | | -ast.Inspect(f, func(n ast.Node) bool { |
144 | | -switch stmt := n.(type) { |
145 | | -case *ast.AssignStmt: |
146 | | -var errIds []*ast.Ident |
147 | | -for _, expr := range stmt.Lhs { |
148 | | -if !isErrorType(pass.TypesInfo.TypeOf(expr)) { |
149 | | -continue |
150 | | -} |
151 | | -if id, ok := expr.(*ast.Ident); ok { |
152 | | -errIds = append(errIds, id) |
153 | | -} |
154 | | -} |
155 | | -if len(errIds) > 0 { |
156 | | -// Detect wrapped error assignment |
157 | | -var wrapped bool |
158 | | -for _, expr := range stmt.Rhs { |
159 | | -if cexpr, ok := expr.(*ast.CallExpr); ok { |
160 | | -if isWrapped(cexpr) { |
161 | | -wrapped = true |
162 | | -} |
163 | | -} |
164 | | -} |
165 | | -for _, id := range errIds { |
166 | | -errIdents[id.Name] = &errIdent{Ident: id, wrapped: wrapped} |
167 | | -} |
168 | | -} |
169 | | -case *ast.ReturnStmt: |
170 | | -switch len(stmt.Results) { |
171 | | -case 0: |
172 | | -// Named return values |
173 | | -for n := range errNames { |
174 | | -if errIdent, ok := errIdents[n]; ok && !errIdent.wrapped { |
175 | | -recordUnwrappedError(errIdent.Pos(), stmt.Return) |
176 | | -} |
177 | | -} |
178 | | -case len(errInReturn): |
179 | | -// Simple return |
180 | | -for i, expr := range stmt.Results { |
181 | | -if !errInReturn[i] { |
182 | | -continue |
183 | | -} |
184 | | -switch expr := expr.(type) { |
185 | | -case *ast.Ident: |
186 | | -if errIdent, ok := errIdents[expr.Name]; ok && !errIdent.wrapped { |
187 | | -recordUnwrappedError(errIdent.Pos(), expr.NamePos) |
188 | | -} |
189 | | -case *ast.CallExpr: |
190 | | -if !isWrapped(expr) { |
191 | | -recordUnwrappedError(expr.Pos(), expr.Lparen) |
192 | | -} |
193 | | -default: |
194 | | -// TODO: should report unexpected exper |
195 | | -} |
196 | | -} |
197 | | -case 1: |
198 | | -// Return another function directly |
199 | | -switch expr := stmt.Results[0].(type) { |
200 | | -case *ast.CallExpr: |
201 | | -if !isWrapped(expr) { |
202 | | -recordUnwrappedError(expr.Pos(), expr.Pos()) |
203 | | -} |
204 | | -default: |
205 | | -// TODO: should report unexpected exper |
206 | | -} |
207 | | -default: |
208 | | -// TODO: should report unexpected exper |
209 | | -} |
210 | | -} |
211 | | -return true |
212 | | -}) |
| 70 | +NewChecker(pass.Fset, pass.TypesInfo, n.(*ast.FuncDecl)).Check(reportFunc) |
213 | 71 | }) |
214 | 72 |
|
215 | 73 | return nil, nil |
216 | 74 | } |
217 | | - |
218 | | -func sprintInlineCode(s string) string { |
219 | | -cc := 1 |
220 | | -c := cc |
221 | | -for _, r := range s { |
222 | | -if r == '`' { |
223 | | -cc++ |
224 | | -if cc > c { |
225 | | -c = cc |
226 | | -} |
227 | | -} else { |
228 | | -cc = 1 |
229 | | -} |
230 | | -} |
231 | | -q := strings.Repeat("`", c) |
232 | | -return q + s + q |
233 | | -} |
0 commit comments