Added test of fast refitting
cehongwang committed Aug 13, 2024
commit 3d3d59dae9b83554a8c49fa271c8236c054476b9
269 changes: 264 additions & 5 deletions tests/py/dynamo/models/test_model_refit.py
@@ -81,6 +81,260 @@ def test_mapping():
torch._dynamo.reset()


@pytest.mark.unit
def test_fast_refit_one_engine():

model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False

exp_program = torch.export.export(model, tuple(inputs))
exp_program2 = torch.export.export(model2, tuple(inputs))

trt_gm = torchtrt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
)

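    # Refit the compiled module with the weights from exp_program2 using the fast refit path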
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
fast_refit=True,
)

# Check the output
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
*inputs
)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assertions.assertTrue(
torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
"Refit Result is not correct. Refit failed",
)

    # Clean up model env
    torch._dynamo.reset()


@pytest.mark.unit
def test_fast_refit_one_engine_bert():
inputs = [
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
model2 = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
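    # Re-initialize the word embeddings so model2's weights differ from model's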
nn.init.xavier_normal_(model2.embeddings.word_embeddings.weight)
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False

exp_program = torch.export.export(model, tuple(inputs))
exp_program2 = torch.export.export(model2, tuple(inputs))

trt_gm = torchtrt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
)

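    # Fast refit with the re-initialized BERT weights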
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
fast_refit=True,
)

# Check the output
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
*inputs
)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
if not isinstance(expected_output, torch.Tensor) or not isinstance(
refitted_output, torch.Tensor
):
continue
assertions.assertTrue(
torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
"Refit Result is not correct. Refit failed",
)

    # Clean up model env
    torch._dynamo.reset()


@pytest.mark.unit
def test_fast_refit_one_engine_inline_runtime():
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False

exp_program = torch.export.export(model, tuple(inputs))
exp_program2 = torch.export.export(model2, tuple(inputs))

trt_gm = torchtrt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
)
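    # Serialize and reload the compiled module so refit runs against the re-imported exported program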
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
trt_gm = torch.export.load(trt_ep_path)
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
fast_refit=True,
)

# Check the output
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
*inputs
)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assertions.assertTrue(
torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
"Refit Result is not correct. Refit failed",
)

    # Clean up model env
    torch._dynamo.reset()


@pytest.mark.unit
def test_fast_refit_one_engine_python_runtime():

model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = True

exp_program = torch.export.export(model, tuple(inputs))
exp_program2 = torch.export.export(model2, tuple(inputs))

trt_gm = torchtrt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
)

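    # Fast refit the module compiled with the Python runtime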
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
fast_refit=True,
)

# Check the output
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
*inputs
)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assertions.assertTrue(
torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
"Refit Result is not correct. Refit failed",
)

# Clean up model env
torch._dynamo.reset()


@pytest.mark.unit
def test_fast_refit_multiple_engine():

class net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
self.bn = nn.BatchNorm2d(12)
self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
self.fc1 = nn.Linear(12 * 56 * 56, 10)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.bn(x)
x = F.max_pool2d(x, (2, 2))
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, (2, 2))
x = torch.flatten(x, 1)
return self.fc1(x)

model = net().eval().to("cuda")
model2 = net().eval().to("cuda")

inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False

exp_program = torch.export.export(model, tuple(inputs))
exp_program2 = torch.export.export(model2, tuple(inputs))

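    # Run convolutions in PyTorch so the graph is partitioned into multiple TRT engines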
torch_executed_ops = {torch.ops.aten.convolution.default}
trt_gm = torchtrt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
torch_executed_ops=torch_executed_ops,
)

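    # Fast refit each TRT engine in the partitioned module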
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
fast_refit=True,
)

# Check the output
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
*inputs
)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assertions.assertTrue(
torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
"Refit Result is not correct. Refit failed",
)

    # Clean up model env
    torch._dynamo.reset()


@pytest.mark.unit
def test_refit_one_engine():

@@ -108,7 +362,8 @@ def test_refit_one_engine():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
arg_inputs=inputs,
inputs=inputs,
fast_refit=False,
)

# Check the output
@@ -154,7 +409,8 @@ def test_refit_one_engine_bert():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
arg_inputs=inputs,
inputs=inputs,
fast_refit=False,
)

# Check the output
@@ -203,7 +459,8 @@ def test_refit_one_engine_inline_runtime():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
arg_inputs=inputs,
inputs=inputs,
fast_refit=False,
)

# Check the output
@@ -247,7 +504,8 @@ def test_refit_one_engine_python_runtime():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
arg_inputs=inputs,
inputs=inputs,
fast_refit=False,
)

# Check the output
@@ -313,7 +571,8 @@ def forward(self, x):
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
arg_inputs=inputs,
inputs=inputs,
fast_refit=False,
)

# Check the output