Skip to content

Commit 5b80add

Browse files
Turn off cuda malloc by default when --fast autotune is turned on. (#10393)
1 parent 9da397e commit 5b80add

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

comfy/model_management.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -371,6 +371,9 @@ def amd_min_version(device=None, min_rdna_version=0):
371371
except:
372372
pass
373373

374+
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
375+
torch.backends.cudnn.benchmark = True
376+
374377
try:
375378
if torch_version_numeric >= (2, 5):
376379
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)

comfy/ops.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -67,9 +67,6 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
6767

6868
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
6969

70-
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
71-
torch.backends.cudnn.benchmark = True
72-
7370
def cast_to_input(weight, input, non_blocking=False, copy=True):
7471
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
7572

cuda_malloc.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import os
22
import importlib.util
3-
from comfy.cli_args import args
3+
from comfy.cli_args import args, PerformanceFeature
44
import subprocess
55

66
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
@@ -75,8 +75,9 @@ def cuda_malloc_supported():
7575
spec.loader.exec_module(module)
7676
version = module.__version__
7777

78-
if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch
79-
args.cuda_malloc = cuda_malloc_supported()
78+
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
79+
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
80+
args.cuda_malloc = cuda_malloc_supported()
8081
except:
8182
pass
8283

0 commit comments

Comments (0)