Skip to content

Commit 892ebf9

Browse files
committed
Infer parameters type of the traced function from the original function
1 parent 4f25fe8 commit 892ebf9

File tree

3 files changed

+91
-88
lines changed

3 files changed

+91
-88
lines changed

src/passes/TraceCalls.cpp

Lines changed: 64 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -21,33 +21,24 @@
2121
// The pass supports SIMD but the multi-value feature is not supported yet.
2222
//
2323
// Instrumenting void free(void*):
24-
// Before:
25-
// (call $free (i32.const 64))
26-
//
27-
// After:
28-
// (local $1 i32)
29-
// (block
30-
// (call $free
31-
// (local.tee $1
32-
// (i32.const 64)
33-
// )
34-
// )
35-
// (call $trace_free
36-
// (local.get $1)
37-
// )
38-
// )
24+
25+
// Instrumenting function `void* malloc(int32_t)` with a user-defined
26+
// name of the tracer `trace_alloc` and function `void free(void*)`
27+
// with the default name of the tracer `trace_free` (`trace_` prefix
28+
// is added by default):
29+
// wasm-opt --trace-calls=malloc:trace_alloc,free -o test-opt.wasm test.wasm
3930
//
40-
// Instrumenting void* malloc(int32_t) with a user-defined tracer
41-
// (trace_allocation):
4231
// Before:
4332
// (call $malloc
4433
// (local.const 32))
34+
// (call $free (i32.const 64))
4535
//
4636
// After:
4737
// (local $0 i32)
4838
// (local $1 i32)
39+
// (local $2 i32)
4940
// (block (result i32)
50-
// (call $trace_allocation
41+
// (call $trace_alloc
5142
// (local.get $0)
5243
// (local.tee $1
5344
// (call $malloc
@@ -56,9 +47,18 @@
5647
// )
5748
// )
5849
// )
59-
//
50+
// (block
51+
// (call $free
52+
// (local.tee $3
53+
// (i32.const 64)
54+
// )
55+
// )
56+
// (call $trace_free
57+
// (local.get $3)
58+
// )
59+
// )
6060

61-
#include <set>
61+
#include <map>
6262

6363
#include "asmjs/shared-constants.h"
6464
#include "ir/import-utils.h"
@@ -68,45 +68,17 @@
6868

6969
namespace wasm {
7070

71-
struct TracedFunction {
72-
Name originName;
73-
Name tracerName;
74-
Type params;
75-
Type results;
76-
};
77-
bool operator<(const TracedFunction& lhs, const TracedFunction& rhs) {
78-
return lhs.originName < rhs.originName;
79-
}
80-
81-
static Type strToType(const std::string& strType) {
82-
if (strType == "i32") {
83-
return Type::i32;
84-
} else if (strType == "i64") {
85-
return Type::i64;
86-
} else if (strType == "f32") {
87-
return Type::f32;
88-
} else if (strType == "f64") {
89-
return Type::f64;
90-
} else if (strType == "v128") {
91-
return Type::v128;
92-
} else {
93-
Fatal() << "Failed to parse type '" << strType << "'";
94-
}
95-
}
71+
using TracedFunctions = std::map<Name /* originName */, Name /* tracerName */>;
9672

9773
struct AddTraceWrappers : public WalkerPass<PostWalker<AddTraceWrappers>> {
98-
AddTraceWrappers(std::set<TracedFunction> tracedFunctions)
74+
AddTraceWrappers(TracedFunctions tracedFunctions)
9975
: tracedFunctions(std::move(tracedFunctions)) {}
10076
void visitCall(Call* curr) {
10177
auto* target = getModule()->getFunction(curr->target);
10278

103-
auto iter = std::find_if(tracedFunctions.begin(),
104-
tracedFunctions.end(),
105-
[target](const TracedFunction& f) {
106-
return f.originName == target->name;
107-
});
79+
auto iter = tracedFunctions.find(target->name);
10880
if (iter != tracedFunctions.end()) {
109-
addInstrumentation(curr, target, iter->tracerName);
81+
addInstrumentation(curr, target, iter->second);
11082
}
11183
}
11284

@@ -136,14 +108,15 @@ struct AddTraceWrappers : public WalkerPass<PostWalker<AddTraceWrappers>> {
136108
{builder.makeCall(
137109
wrapperName, trackerCallParams, Type::BasicType::none),
138110
builder.makeLocalGet(resultLocal, resultType)}));
139-
} else
111+
} else {
140112
replaceCurrent(builder.makeBlock(
141113
{realCall,
142114
builder.makeCall(
143115
wrapperName, trackerCallParams, Type::BasicType::none)}));
116+
}
144117
}
145118

146-
std::set<TracedFunction> tracedFunctions;
119+
TracedFunctions tracedFunctions;
147120
};
148121

149122
struct TraceCalls : public Pass {
@@ -154,39 +127,55 @@ struct TraceCalls : public Pass {
154127
auto functionsDefinitions = getPassOptions().getArgument(
155128
"trace-calls",
156129
"TraceCalls usage: wasm-opt "
157-
"--trace-calls=FUNCTION_TO_TRACE[:TRACER_NAME][,RESULT1_TYPE]"
158-
"[,PARAM1_TYPE[,PARAM2_TYPE[,...]]][;...]");
130+
"--trace-calls=FUNCTION_TO_TRACE[:TRACER_NAME][,...]");
159131

160132
auto tracedFunctions = parseArgument(functionsDefinitions);
161133

162134
for (const auto& tracedFunction : tracedFunctions) {
163-
addImport(module, tracedFunction);
135+
auto func = module->getFunctionOrNull(tracedFunction.first);
136+
if (!func) {
137+
std::cerr << "[TraceCalls] Function '" << tracedFunction.first
138+
<< "' not found" << std::endl;
139+
} else {
140+
addImport(module, *func, tracedFunction.second);
141+
}
164142
}
165143

166144
AddTraceWrappers(std::move(tracedFunctions)).run(getPassRunner(), module);
167145
}
168146

169147
private:
170-
std::set<TracedFunction> parseArgument(const std::string& arg) {
171-
std::set<TracedFunction> tracedFunctions;
148+
Type getTracerParamsType(ImportInfo& info, const Function& func) {
149+
auto resultsType = func.type.getSignature().results;
150+
if (resultsType.isTuple()) {
151+
Fatal() << "Failed to instrument function '" << func.name
152+
<< "': Multi-value result type is not supported";
153+
}
172154

173-
for (const auto& definition : String::Split(arg, ";")) {
174-
auto parts = String::Split(definition, ",");
175-
if (parts.size() == 0) {
155+
std::vector<Type> tracerParamTypes;
156+
if (resultsType.isConcrete()) {
157+
tracerParamTypes.push_back(resultsType);
158+
}
159+
for (auto& op : func.type.getSignature().params) {
160+
tracerParamTypes.push_back(op);
161+
}
162+
163+
return Type(tracerParamTypes);
164+
}
165+
166+
TracedFunctions parseArgument(const std::string& arg) {
167+
TracedFunctions tracedFunctions;
168+
169+
for (const auto& definition : String::Split(arg, ",")) {
170+
if (definition.empty()) {
176171
// Empty definition, ignore.
177172
continue;
178173
}
179174

180175
std::string originName, traceName;
181-
parseFunctionName(parts[0], originName, traceName);
182-
183-
std::vector<Type> paramsAndResults;
184-
for (size_t i = 1; i < parts.size(); i++) {
185-
paramsAndResults.push_back(strToType(parts[i]));
186-
}
176+
parseFunctionName(definition, originName, traceName);
187177

188-
tracedFunctions.emplace(TracedFunction{
189-
Name(originName), Name(traceName), paramsAndResults, Type::none});
178+
tracedFunctions[Name(originName)] = Name(traceName);
190179
}
191180

192181
return tracedFunctions;
@@ -211,14 +200,14 @@ struct TraceCalls : public Pass {
211200
}
212201
}
213202

214-
void addImport(Module* wasm, const TracedFunction& f) {
203+
void addImport(Module* wasm, const Function& f, const Name& tracerName) {
215204
ImportInfo info(*wasm);
216205

217-
if (!info.getImportedFunction(ENV, f.tracerName)) {
218-
auto import =
219-
Builder::makeFunction(f.tracerName, Signature(f.params, f.results), {});
206+
if (!info.getImportedFunction(ENV, tracerName)) {
207+
auto import = Builder::makeFunction(
208+
tracerName, Signature(getTracerParamsType(info, f), Type::none), {});
220209
import->module = ENV;
221-
import->base = f.tracerName;
210+
import->base = tracerName;
222211
wasm->addFunction(std::move(import));
223212
}
224213
}

test/lit/passes/trace-calls.wast

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,23 @@
11
;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited.
22

3-
;; RUN: foreach %s %t wasm-opt --enable-simd --trace-calls="noparamsnoresults;singleparamnoresults,f64;multiparamsnoresults:tracempnr,i32,i64,f32;noparamssingleresult,v128;multiparamssingleresult,v128,i32,v128" -S -o - | filecheck %s
3+
;; RUN: wasm-opt --enable-simd --trace-calls="noparamsnoresults,singleparamnoresults,multiparamsnoresults:tracempnr,noparamssingleresult,multiparamssingleresult" %s -S -o - | filecheck %s
44

55
(module
6+
7+
(import "env" "no_params_no_results"
8+
(func $noparamsnoresults))
9+
(import "env" "single_param_no_results"
10+
(func $singleparamnoresults (param f64)))
11+
(import "env" "multi_params_no_results"
12+
(func $multiparamsnoresults (param i32 i64 f32)))
13+
(import "env" "no_params_single_result"
14+
(func $noparamssingleresult (result v128)))
15+
(import "env" "multi_params_single_result"
16+
(func $multiparamssingleresult (param i32 v128)(result v128)))
17+
(import "env" "dont_trace_me"
18+
(func $donttraceme))
19+
20+
621
;; CHECK: (type $0 (func))
722

823
;; CHECK: (type $1 (func (result v128)))
@@ -18,15 +33,7 @@
1833
;; CHECK: (type $6 (func (param v128)))
1934

2035
;; CHECK: (import "env" "no_params_no_results" (func $noparamsnoresults))
21-
(import "env" "no_params_no_results" (func $noparamsnoresults))
22-
(import "env" "single_param_no_results"
23-
(func $singleparamnoresults (param f64)))
24-
(import "env" "multi_params_no_results"
25-
(func $multiparamsnoresults (param i32 i64 f32)))
26-
(import "env" "no_params_single_result"
27-
(func $noparamssingleresult (result v128)))
28-
(import "env" "multi_params_single_result"
29-
(func $multiparamssingleresult (param i32 v128)(result v128)))
36+
3037
;; CHECK: (import "env" "single_param_no_results" (func $singleparamnoresults (param f64)))
3138

3239
;; CHECK: (import "env" "multi_params_no_results" (func $multiparamsnoresults (param i32 i64 f32)))
@@ -36,8 +43,6 @@
3643
;; CHECK: (import "env" "multi_params_single_result" (func $multiparamssingleresult (param i32 v128) (result v128)))
3744

3845
;; CHECK: (import "env" "dont_trace_me" (func $donttraceme))
39-
(import "env" "dont_trace_me" (func $donttraceme))
40-
4146

4247
;; CHECK: (import "env" "tracempnr" (func $tracempnr (param i32 i64 f32)))
4348

@@ -146,6 +151,6 @@
146151
;; CHECK-NEXT: (call $donttraceme)
147152
;; CHECK-NEXT: )
148153
(func $test_dont_trace_me
149-
call $donttraceme
154+
(call $donttraceme)
150155
)
151156
)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
;; Test that a traced function with a multi-value result type
2+
;; results in a useful error message
3+
4+
;; RUN: not wasm-opt --enable-simd --enable-multivalue --trace-calls=multi_param_result %s 2>&1 | filecheck %s
5+
6+
;; CHECK: Fatal: Failed to instrument function 'multi_param_result': Multi-value result type is not supported
7+
(module
8+
(import "env" "multi_param_result" (func $multi_param_result (result i32 i32)))
9+
)

0 commit comments

Comments
 (0)