Skip to content

Commit 3c4aad5

Browse files
committed
first version of model evaluator
1 parent 505a1cf commit 3c4aad5

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

priv/python/evaluate.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import torch
2+
from sentence_transformers import SentenceTransformer, util
3+
# Used to create and store the Faiss index.
4+
import faiss
5+
import numpy as np
6+
7+
import evaluate_data
8+
import evaluate_model
9+
10+
def predict(model, text):
    """Return the embedding of a single text.

    Args:
        model: Object exposing ``encode(sentences, show_progress_bar=...)``;
            ``load`` in this file supplies a ``SentenceTransformer``.
        text: The sentence to embed.

    Returns:
        The first (and only) item of the batch returned by ``model.encode``.
    """
    # encode() is given a one-item batch, so unwrap the single result.
    embeddings = model.encode([text], show_progress_bar=False)
    return embeddings[0]
14+
15+
def pair_similarity(model, a, b):
    """Return the cosine similarity between the embeddings of two texts.

    Args:
        model: Embedding model, passed through to ``predict``.
        a: First text.
        b: Second text.

    Returns:
        The value produced by ``util.pytorch_cos_sim`` for the two
        embeddings.
    """
    emb_a = predict(model, a)
    emb_b = predict(model, b)
    return util.pytorch_cos_sim(emb_a, emb_b)
18+
19+
def test_example(model, test_name, test):
    """Check that the 'close' pair is at least as similar as the 'far' pair.

    Args:
        model: Embedding model, passed through to ``pair_similarity``.
        test_name: Label of the test; not used in the comparison itself.
        test: Mapping with ``'close'`` and ``'far'`` entries, each holding
            two texts (indexes 0 and 1 are read).

    Returns:
        ``[passed, close_sim, close_pair, far_sim, far_pair]`` where
        ``passed`` is ``False`` only when the close similarity is strictly
        lower than the far similarity.
    """
    close = test['close']
    far = test['far']
    good_sim = pair_similarity(model, close[0], close[1])
    bad_sim = pair_similarity(model, far[0], far[1])
    # Only a strictly lower "close" score is a failure; ties pass.
    passed = not (good_sim < bad_sim)
    return [passed, good_sim, close, bad_sim, far]
28+
29+
def test_group(model_name, model, group_name, test_list):
    """Run every test of one group against one model and collect failures.

    Each failing test is printed together with both sentence pairs and
    their similarities.

    Args:
        model_name: Name of the model, used in the report and the records.
        model: Embedding model, passed through to ``test_example``.
        group_name: Name of the group, used in the report and the records.
        test_list: Mapping of test name to test definition.

    Returns:
        A list with one record per failed test:
        ``[model_name, group_name, test_name, close_sim, close_pair,
        far_sim, far_pair]``.
    """
    failures = []
    for test_name, test in test_list.items():
        status, close_sim, close_pair, far_sim, far_pair = test_example(
            model, test_name, test
        )
        if status:
            continue
        print(f'{model_name} - {group_name} - {test_name} - FAILED')
        print(f'{close_pair} -> {close_sim}')
        print(f'{far_pair} -> {far_sim}')
        failures.append([
            model_name, group_name, test_name,
            close_sim, close_pair, far_sim, far_pair,
        ])
    return failures
40+
41+
def load(model_name):
    """Construct the ``SentenceTransformer`` identified by ``model_name``."""
    return SentenceTransformer(model_name)
43+
44+
def start():
    """Evaluate every configured model on every evaluation group.

    Iterates ``evaluate_model.MODEL``, loads each model, runs all groups
    from ``evaluate_data.EVAL_GROUPS`` through ``test_group`` and finally
    prints the number of failures per model.

    Returns:
        ``(model_failures, group_failures)``: two dicts mapping a model
        name / group name to the list of failure records produced by
        ``test_group``.
    """
    model_failures = {}
    group_failures = {}
    for model_name in evaluate_model.MODEL:
        model = load(model_name)
        model_failures[model_name] = []
        for group_name, tests in evaluate_data.EVAL_GROUPS.items():
            failures = test_group(model_name, model, group_name, tests)
            model_failures[model_name].extend(failures)
            # setdefault (not plain assignment) so that failures recorded
            # for earlier models are kept instead of being wiped each time
            # the outer loop moves on to the next model.
            group_failures.setdefault(group_name, []).extend(failures)

    print('model results:')
    for model_name, failures in model_failures.items():
        print(f'model: {model_name} failures: {len(failures)}')

    return model_failures, group_failures
63+
64+
# Kick off the evaluation as soon as this file is executed.
# NOTE(review): there is no `__name__ == "__main__"` guard, so importing this
# file would also run the whole evaluation - confirm that is intended.
start()

priv/python/evaluate_data.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
# Evaluation data: group name -> test name -> sentence pairs.
# Each test holds a 'close' pair and a 'far' pair; the evaluator reads
# indexes 0 and 1 of each pair.
EVAL_GROUPS = {
    "test group 1": {
        "test 1": {
            "close": ["I like cats", "I like kittens"],
            "far": ["I like cats", "I like sharp knives and dead bodies"],
        },
        "test 2": {
            "close": ["I like cats", "I like dogs"],
            "far": ["I like cats", "The first president of the United States"],
        },
    },
}
16+
17+

priv/python/evaluate_model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Model names to evaluate, in evaluation order. Each entry is handed
# unchanged to the model loader of the evaluation script.
MODEL = [
    "paraphrase-mpnet-base-v2",
    "paraphrase-multilingual-mpnet-base-v2",
    "paraphrase-distilroberta-base-v2",
    "paraphrase-MiniLM-L6-v2",
    "paraphrase-MiniLM-L3-v2",
    "stsb-mpnet-base-v2",
    "nli-mpnet-base-v2",
    "stsb-distilroberta-base-v2",
    "nli-roberta-base-v2",
    "stsb-roberta-base-v2",
    "nli-distilroberta-base-v2",
    "average_word_embeddings_komninos",
    "msmarco-distilbert-base-v3",
]

0 commit comments

Comments
 (0)