Merged
Changes from 1 commit
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Failed to load files.
PrevPrevious commit
Next Next commit
Added more tests
  • Loading branch information
@cehongwang
cehongwang committedJun 16, 2025
commit 32c851fa7c5916b51933c328394d02f2a924e3c1
Original file line numberDiff line numberDiff line change
Expand Up@@ -372,6 +372,44 @@ def test_resnet18_dynamic(ir):
)


@pytest.mark.unit
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just add this to help @lanluo-nvidia out

@unittest.skipIf(
        not importlib.util.find_spec("torchvision"), "torchvision not installed"
    )

Then feel free to merge

def test_resnet18_dynamic_torch_exec_ops(ir):
"""
This tests export save and load functionality on Resnet18 model
"""
model = models.resnet18().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
"inputs": [
torchtrt.Input(
min_shape=(1, 3, 224, 224),
opt_shape=(4, 3, 224, 224),
max_shape=(8, 3, 224, 224),
dtype=torch.float32,
name="x",
)
],
"ir": ir,
"min_block_size": 1,
"cache_built_engines": False,
"reuse_cached_engines": False,
"torch_executed_ops": {torch.ops.aten.addmm, "torch.ops.aten.add"},
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path)
deser_trt_module = torchtrt.load(trt_ep_path).module()
outputs_pyt = deser_trt_module(input)
outputs_trt = trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_hybrid_conv_fallback(ir):
"""
Expand Down
Loading