|
| 1 | +from typing import Any, Callable |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pytest |
| 5 | +from pyversity import cover, diversify, dpp, mmr, msd |
| 6 | +from pyversity.datatypes import Metric, Strategy |
| 7 | + |
| 8 | + |
def test_mmr() -> None:
    """Exercise the MMR strategy across relevance/diversity trade-offs."""
    # lambda_param=1.0 reduces MMR to plain top-k selection by score.
    identity_emb = np.eye(5, dtype=np.float32)
    relevance = np.array([0.1, 0.9, 0.3, 0.8, 0.2], dtype=np.float32)
    picked, picked_gains = mmr(identity_emb, relevance, k=3, lambda_param=1.0, metric=Metric.COSINE, normalize=True)
    want = np.array([1, 3, 2], dtype=np.int32)
    assert np.array_equal(picked, want)
    assert np.allclose(picked_gains, relevance[want])

    # With a near-duplicate of item 0 present, both diversity-only (0.0)
    # and balanced (0.5) settings must skip it in favor of item 2.
    clustered = np.array([[1.0, 0.0], [0.999, 0.001], [0.0, 1.0]], dtype=np.float32)
    clustered_scores = np.array([1.0, 0.99, 0.98], dtype=np.float32)
    for lam in (0.0, 0.5):
        picked, _ = mmr(clustered, clustered_scores, k=2, lambda_param=lam, metric=Metric.COSINE, normalize=True)
        assert (picked[0], picked[1]) == (0, 2)

    # lambda_param outside [0, 1] must be rejected.
    with pytest.raises(ValueError):
        mmr(np.eye(2, dtype=np.float32), np.array([1.0, 0.5], dtype=np.float32), k=1, lambda_param=-0.1)
| 32 | + |
| 33 | + |
def test_msd() -> None:
    """Exercise the MSD strategy across relevance/diversity trade-offs."""
    # lambda_param=1.0 reduces MSD to plain top-k selection by score.
    identity_emb = np.eye(4, dtype=np.float32)
    relevance = np.array([0.5, 0.2, 0.9, 0.1], dtype=np.float32)
    picked, _ = msd(identity_emb, relevance, k=2, lambda_param=1.0, metric=Metric.COSINE, normalize=True)
    assert np.array_equal(picked, np.array([2, 0], dtype=np.int32))

    # With a near-duplicate of item 0 present, both diversity-only (0.0)
    # and balanced (0.5) settings must skip it in favor of item 2.
    clustered = np.array([[1.0, 0.0], [0.999, 0.001], [0.0, 1.0]], dtype=np.float32)
    clustered_scores = np.array([1.0, 0.99, 0.98], dtype=np.float32)
    for lam in (0.0, 0.5):
        picked, _ = msd(clustered, clustered_scores, k=2, lambda_param=lam, metric=Metric.COSINE, normalize=True)
        assert (picked[0], picked[1]) == (0, 2)

    # lambda_param outside [0, 1] must be rejected.
    with pytest.raises(ValueError):
        msd(np.eye(2, dtype=np.float32), np.array([1.0, 0.5], dtype=np.float32), k=1, lambda_param=1.1)
| 55 | + |
| 56 | + |
def test_cover() -> None:
    """Exercise the COVER strategy's theta/gamma behavior and validation."""
    basis = np.eye(3, dtype=np.float32)
    rel = np.array([0.1, 0.8, 0.3], dtype=np.float32)

    # theta=1.0 means pure relevance: plain top-k by score.
    chosen, chosen_gains = cover(basis, rel, k=2, theta=1.0)
    top_two = np.array([1, 2], dtype=np.int32)
    assert np.array_equal(chosen, top_two)
    assert np.allclose(chosen_gains, rel[top_two])

    # Balanced coverage still leads with the best item; the second pick
    # may be either remaining orthogonal axis.
    chosen, _ = cover(basis, rel, k=2, theta=0.5, gamma=0.5)
    assert chosen[0] == 1
    assert chosen[1] in (0, 2)

    # Out-of-range theta and non-positive gamma are all rejected.
    for bad in ({"theta": -0.01}, {"theta": 1.01}, {"gamma": 0.0}, {"gamma": -0.5}):
        with pytest.raises(ValueError):
            cover(basis, rel, k=2, **bad)
| 81 | + |
| 82 | + |
def test_dpp() -> None:
    """Exercise the DPP strategy across beta settings and the empty case."""
    vecs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=np.float32)
    rel = np.array([0.1, 0.2, 0.3], dtype=np.float32)

    def check_selection(sel: np.ndarray, sel_gains: np.ndarray, k: int) -> None:
        # DPP may stop early, so the size is bounded by k rather than exact.
        assert 1 <= sel.size <= k
        # Gains must be non-negative (up to float noise) and non-increasing.
        assert np.all(sel_gains >= -1e-7)
        assert np.all(sel_gains[:-1] + 1e-7 >= sel_gains[1:])

    # beta=0.0 (diversity-only kernel), beta=1.0, and balanced beta=0.5.
    for k, beta in ((3, 0.0), (2, 1.0), (2, 0.5)):
        sel, sel_gains = dpp(vecs, rel, k=k, beta=beta)
        check_selection(sel, sel_gains, k)

    # Zero rows in -> zero selections out.
    sel, sel_gains = dpp(np.empty((0, 3), dtype=np.float32), np.array([]), k=3)
    assert sel.size == 0 and sel_gains.size == 0
| 109 | + |
| 110 | + |
@pytest.mark.parametrize(
    "strategy, fn, kwargs",
    [
        (Strategy.MMR, mmr, {"lambda_param": 0.5, "metric": Metric.COSINE, "normalize": True}),
        (Strategy.MSD, msd, {"lambda_param": 0.5, "metric": Metric.COSINE, "normalize": True}),
        (Strategy.COVER, cover, {"theta": 0.5, "gamma": 0.5}),
        (Strategy.DPP, dpp, {"beta": 0.5}),
    ],
)
def test_diversify(strategy: Strategy, fn: Callable[..., Any], kwargs: dict[str, Any]) -> None:
    """The ``diversify`` dispatcher must match each strategy function exactly.

    Runs the strategy-specific function directly and via ``diversify`` with
    identical inputs, then checks that both the selected indices and the
    per-selection gains agree.
    """
    emb = np.eye(4, dtype=np.float32)
    scores = np.array([0.3, 0.7, 0.1, 0.5], dtype=np.float32)

    idx_direct, gains_direct = fn(emb, scores, k=2, **kwargs)
    idx_disp, gains_disp = diversify(strategy, embeddings=emb, scores=scores, k=2, **kwargs)

    # Indices are compared exactly; gains with float tolerance.
    assert np.array_equal(idx_direct, idx_disp)
    assert np.allclose(gains_direct, gains_disp)
0 commit comments