 # > conda create -n myenv python=3.10
 # > pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
 # > pip install git+https://github.com/facebookresearch/segment-anything.git
-# > pip install git+https://github.com/pytorch-labs/ao.git
+# > pip install git+https://github.com/pytorch/ao.git
 #
 # Segment Anything Model checkpoint setup:
 #
@@ -44,7 +44,7 @@
 #

 import torch
-from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.quantization.quant_api import quantize_, Int8DynamicActivationInt8WeightConfig
 from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer
@@ -143,7 +143,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # for improvements.
 #
 # Next, let's apply quantization. Quantization for GPUs comes in three main forms
-# in `torchao <https://github.com/pytorch-labs/ao>`_ which is just native
+# in `torchao <https://github.com/pytorch/ao>`_ which is just native
 # pytorch+python code. This includes:
 #
 # * int8 dynamic quantization
@@ -157,9 +157,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # in memory bound situations where the benefit comes from loading less
 # weight data, rather than doing less computation. The torchao APIs:
 #
-# ``int8_dynamic_activation_int8_weight()``,
-# ``int8_weight_only()`` or
-# ``int4_weight_only()``
+# ``Int8DynamicActivationInt8WeightConfig()``,
+# ``Int8WeightOnlyConfig()`` or
+# ``Int4WeightOnlyConfig()``
 #
 # can be used to easily apply the desired quantization technique and then
 # once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
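For orientation, here is a minimal sketch of the flow this hunk describes, applied to a toy stack of linear layers rather than the SAM encoder. It assumes only the import path used in this diff (``torchao.quantization.quant_api``), a CUDA device, and a torchao/PyTorch build recent enough to ship the config classes; it is an illustration, not part of the tutorial's measured code:

    import torch
    from torchao.quantization.quant_api import quantize_, Int8DynamicActivationInt8WeightConfig

    # Toy stand-in for the model being quantized in the tutorial.
    toy_model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.GELU(),
        torch.nn.Linear(4096, 1024),
    ).cuda().to(torch.bfloat16)

    # Replace each nn.Linear's weight with a quantized tensor subclass, in place.
    quantize_(toy_model, Int8DynamicActivationInt8WeightConfig())

    # Compile with max-autotune so inductor can pick fast int8 kernels.
    toy_model_c = torch.compile(toy_model, mode="max-autotune")

    example = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        out = toy_model_c(example)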
@@ -171,7 +171,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # ``apply_weight_only_int8_quant`` instead as drop in replacement for the two
 # above (no replacement for int4).
 #
-# The difference between the two APIs is that ``int8_dynamic_activation`` API
+# The difference between the two APIs is that the ``Int8DynamicActivationInt8WeightConfig`` API
 # alters the weight tensor of the linear module so instead of doing a
 # normal linear, it does a quantized operation. This is helpful when you
 # have non-standard linear ops that do more than one thing. The ``apply``
@@ -186,7 +186,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(only_one_block, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
 if not TORCH_VERSION_AT_LEAST_2_5:
     # needed for subclass + compile to work on older versions of pytorch
     unwrap_tensor_subclass(model)
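The weight-only variants named earlier drop into the same call. This is a sketch only, not something the tutorial measures; it assumes ``Int8WeightOnlyConfig`` and ``Int4WeightOnlyConfig`` are importable from the same module as above, and int4 in particular has dtype and shape constraints (it expects bfloat16 weights) that may not suit every layer:

    from torchao.quantization.quant_api import Int8WeightOnlyConfig, Int4WeightOnlyConfig

    # int8 weight-only: same activation dtype, int8 weights (memory-bandwidth savings).
    model_w8, _ = get_sam_model(only_one_block, batchsize)
    model_w8 = model_w8.to(torch.bfloat16)
    quantize_(model_w8, Int8WeightOnlyConfig())

    # int4 weight-only: groupwise int4 weights, for even less weight data to load.
    model_w4, _ = get_sam_model(only_one_block, batchsize)
    model_w4 = model_w4.to(torch.bfloat16)
    quantize_(model_w4, Int4WeightOnlyConfig())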
@@ -224,7 +224,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
 if not TORCH_VERSION_AT_LEAST_2_5:
     # needed for subclass + compile to work on older versions of pytorch
     unwrap_tensor_subclass(model)
@@ -258,7 +258,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.coordinate_descent_check_all_directions = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
 if not TORCH_VERSION_AT_LEAST_2_5:
     # needed for subclass + compile to work on older versions of pytorch
     unwrap_tensor_subclass(model)
@@ -290,7 +290,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
     model, image = get_sam_model(False, batchsize)
     model = model.to(torch.bfloat16)
     image = image.to(torch.bfloat16)
-    quantize_(model, int8_dynamic_activation_int8_weight())
+    quantize_(model, Int8DynamicActivationInt8WeightConfig())
     if not TORCH_VERSION_AT_LEAST_2_5:
         # needed for subclass + compile to work on older versions of pytorch
         unwrap_tensor_subclass(model)
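The timing code for these batch-size-32 runs falls outside the hunks shown here. As a rough sketch, the compiled, quantized model could be measured with the ``Timer`` imported at the top of the file; ``model``, ``image``, and ``batchsize`` are assumed to be the loop variables from the surrounding tutorial code, not names this diff defines:

    model_c = torch.compile(model, mode="max-autotune")
    with torch.no_grad():
        model_c(image)  # warm up and trigger compilation outside the timed region
        t = Timer(
            stmt="model_c(image)",
            globals={"model_c": model_c, "image": image},
        )
        measurement = t.timeit(50)
        print(f"batchsize {batchsize}: {measurement.mean * 1000:.2f} ms per iteration")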
@@ -315,6 +315,6 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # the model. For example, this can be done with some form of flash attention.
 #
 # For more information visit
-# `torchao <https://github.com/pytorch-labs/ao>`_ and try it on your own
+# `torchao <https://github.com/pytorch/ao>`_ and try it on your own
 # models.
 #