import torch
import pytest
from vector_quantize_pytorch import LFQ
import math

"""
testing strategy:
subdivisions: using masks, using frac_per_sample_entropy < 1
"""

torch.manual_seed(0)

@pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5))
def test_masked_lfq(
    frac_per_sample_entropy
):
    # you can specify either dim or codebook_size
    # if both are specified, they will be validated against each other
    # (see the sketch after this test for the default-dim case)

    quantizer = LFQ(
        codebook_size = 65536, # codebook size, must be a power of 2
        dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
        entropy_loss_weight = 0.1, # how much weight to place on entropy loss
        diversity_gamma = 1., # within entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = frac_per_sample_entropy
    )

    image_feats = torch.randn(2, 16, 32, 32)

    ret, loss_breakdown = quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True) # you may want to experiment with temperature

    quantized, indices, _ = ret
    assert (quantized == quantizer.indices_to_codes(indices)).all()

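# A minimal sketch (not one of the original tests) of the dim-defaulting behavior the comments
# above describe: `dim` is omitted and is assumed to default to log2(codebook_size) = 16, so the
# same (2, 16, 32, 32) features can be quantized unchanged. Kept as a private helper so pytest
# does not collect it.
def _default_dim_sketch():
    quantizer = LFQ(codebook_size = 65536) # dim not given, assumed to default to log2(65536) = 16

    image_feats = torch.randn(2, 16, 32, 32)

    ret, _ = quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True)

    quantized, indices, _ = ret
    assert (quantized == quantizer.indices_to_codes(indices)).all()
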
@pytest.mark.parametrize('frac_per_sample_entropy', (0.1,))
@pytest.mark.parametrize('iters', (10,))
@pytest.mark.parametrize('mask', (None, torch.tensor([True, False])))
def test_lfq_bruteforce_frac_per_sample_entropy(frac_per_sample_entropy, iters, mask):
    image_feats = torch.randn(2, 16, 32, 32)

    full_per_sample_entropy_quantizer = LFQ(
        codebook_size = 65536, # codebook size, must be a power of 2
        dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
        entropy_loss_weight = 0.1, # how much weight to place on entropy loss
        diversity_gamma = 1., # within entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = 1
    )

    partial_per_sample_entropy_quantizer = LFQ(
        codebook_size = 65536, # codebook size, must be a power of 2
        dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
        entropy_loss_weight = 0.1, # how much weight to place on entropy loss
        diversity_gamma = 1., # within entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = frac_per_sample_entropy
    )

    ret, loss_breakdown = full_per_sample_entropy_quantizer(
        image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)
    true_per_sample_entropy = loss_breakdown.per_sample_entropy

    per_sample_losses = torch.zeros(iters)
    for i in range(iters):
        ret, loss_breakdown = partial_per_sample_entropy_quantizer(
            image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask) # you may want to experiment with temperature

        quantized, indices, _ = ret
        assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all()
        per_sample_losses[i] = loss_breakdown.per_sample_entropy
    # the full per-sample entropy should fall within the 95% confidence interval
    # of the mean of the subsampled (frac_per_sample_entropy < 1) estimates
    assert abs(per_sample_losses.mean() - true_per_sample_entropy) \
        < (1.96 * (per_sample_losses.std() / math.sqrt(iters)))

    print("difference: ", abs(per_sample_losses.mean() - true_per_sample_entropy))
    print("std error: ", (1.96 * (per_sample_losses.std() / math.sqrt(iters))))

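# Not part of the original tests: a small helper restating the statistical check above, under
# the assumption that the subsampled per-sample entropy is an unbiased estimate of the full one.
# The full value is expected to fall within 1.96 standard errors of the mean of the repeated
# estimates, i.e. inside a 95% confidence interval on that mean.
def _ci_half_width(samples: torch.Tensor) -> float:
    # 1.96 * standard error of the mean of `samples`
    return (1.96 * samples.std() / math.sqrt(samples.numel())).item()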