diff --git a/examples/test_infer.py b/examples/test_infer.py index a90a5d3e..3a143f90 100644 --- a/examples/test_infer.py +++ b/examples/test_infer.py @@ -18,6 +18,9 @@ def test( attn_backend="default", image_path=None, skip_load=False, + num_blocks=512, + block_size=256, + max_cache_len=4096, ): model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # @@ -39,6 +42,9 @@ def test( enable_graph=enable_graph, attn_backend=attn_backend, skip_load=skip_load, + num_blocks=num_blocks, + block_size=block_size, + max_cache_len=max_cache_len, ) conversations = [ @@ -103,4 +109,7 @@ def test( attn_backend=cfg.attn, image_path=cfg.image, skip_load=cfg.skip_load, + num_blocks=cfg.num_blocks, + block_size=cfg.block_size, + max_cache_len=cfg.max_cache_len, )