
Commit 125c344

Santos Safrao authored and committed
Mutual Subspace Method
1 parent dd6a4c6 commit 125c344

File tree

9 files changed: +262 -11 lines changed


data/CVLABFace2.mat

9.49 MB
Binary file not shown.

examples/msm.m

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+load("data/CVLABFace2.mat")
+training_data = X1;
+testing_data = X2;
+
+[~, num_samples, ~] = size(training_data);
+[num_dim, num_samples_per_set, num_sets, num_classes] = size(testing_data);
+
+num_dim_reference_subspaces = 20;
+num_dim_input_subspaces = 5;
+
+reference_subspaces = computeBasisVectors(training_data, num_dim_reference_subspaces);
+input_subspaces = computeBasisVectors(testing_data, num_dim_input_subspaces);
+similarities = computeSubspacesSimilarities(reference_subspaces, input_subspaces);
+
+model_evaluation = ModelEvaluation(similarities(:, :, end, end), generateLabels(size(testing_data, 3), num_classes));
+
+displayModelResults('Mutual Subspace Methods', model_evaluation);
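computeBasisVectors and computeSubspacesSimilarities are not included in this diff. As orientation, a minimal sketch of what the Mutual Subspace Method computes in these two steps, with random stand-in data and the subspace dimensions used above (the actual helper names, array shapes, and score definition in this repository may differ):

% Sketch only (not part of the commit): orthonormal bases from one image set
% and a subspace-to-subspace similarity from canonical angles, as in MSM.
reference_set = randn(64, 30);            % stand-in for one reference image set (dim x samples)
input_set     = randn(64, 10);            % stand-in for one input image set

[U1, ~, ~] = svd(reference_set, 'econ');  % left singular vectors = eigenvectors of the autocorrelation matrix
U1 = U1(:, 1:20);                         % reference subspace basis (20 dimensions, as above)

[U2, ~, ~] = svd(input_set, 'econ');
U2 = U2(:, 1:5);                          % input subspace basis (5 dimensions, as above)

canonical_cosines = svd(U1' * U2);        % singular values of U1'*U2 are cos(theta_i)
similarity = mean(canonical_cosines.^2);  % a common MSM score: mean squared canonical cosine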

install.m

Lines changed: 0 additions & 1 deletion
@@ -1,2 +1 @@
 addpath(genpath('..')) % Add the subfolders of the project to the path
-addPackageFoldersToPath

src/classes/@ModelEvaluation/ModelEvaluation.m

Lines changed: 5 additions & 2 deletions
@@ -1,5 +1,6 @@
 classdef ModelEvaluation
     properties
+        accuracy; % Accuracy
         error_rate; % Overall error rate
        equal_error_rate; % Equal error rate (point where FAR and FRR are approximately equal)
        classification_threshold; % Threshold used for classification
@@ -43,11 +44,13 @@
             else
                 [~, predicted_labels] = min(evaluation_values, [], 1);
             end
-
-            obj.error_rate = 1 - mean(predicted_labels == labels);
+            obj.accuracy = mean(predicted_labels == labels);
+            obj.error_rate = 1 - obj.accuracy;
         else
             binary_labels = zeros(size(labels));
             binary_labels(labels ~= 0) = 1;
+            predicted_labels = evaluation_values >= obj.classification_threshold;
+            obj.accuracy = mean(predicted_labels == binary_labels);
         end

         evaluation_values = evaluation_values(:);
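A toy sketch of the new accuracy field (not part of the commit): the constructor call follows the pattern of the untitled.m snippet removed further down, the rows-are-classes layout is inferred from the min(evaluation_values, [], 1) call above, and the numbers are invented.

evaluation_values = [0.9 0.2 0.8;    % similarity of each sample to class 1
                     0.1 0.7 0.3];   % similarity of each sample to class 2
labels = [1 2 1];                    % ground-truth class of each sample

model_eval = ModelEvaluation(evaluation_values, labels);
disp("Accuracy: " + model_eval.accuracy);      % new property introduced in this commit
disp("Error rate: " + model_eval.error_rate);  % still equals 1 - accuracy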

src/classes/@OrzEval/OrzEval.m

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+
+classdef OrzEval
+    properties (SetAccess = public)
+        ER;     % classification error rate (multi-class case only)
+        EER;    % equal error rate
+        Thres;  % threshold at the equal-error point
+
+        A;      % sorted evaluation values (used as thresholds)
+        FAR;    % false accept (false alarm) rate at each threshold
+        FRR;    % false reject rate at each threshold
+
+        nP;     % number of positive samples
+        nN;     % number of negative samples
+
+        flgSIM; % true: values are similarities, false: dissimilarities (distances)
+
+    end % properties
+
+    methods
+        function OB = OrzEval(VAL, Label, varargin)
+            % function OB = OrzEval(VAL, Label, varargin)
+            % VAL: matrix (or row vector) of similarities or dissimilarities (distances).
+            %      If VAL is a matrix, the task is treated as multi-class (two or more classes).
+            %      If VAL is a row vector, the task is treated as one-class and ER is not computed.
+            % Label: row vector with as many entries as VAL has columns,
+            %      holding the ground-truth label of each column of VAL.
+            %      Multi-class: values from 1 to the number of classes.
+            %      One-class: 1 (Positive) and 0 (Negative).
+            % Third argument: decides whether VAL holds similarities or dissimilarities (distances).
+            %      Similarities by default;
+            %      if 'D' is passed as the third argument, values are treated as dissimilarities (distances).
+            %
+            % PlotEER: draws the False Reject Rate and False Alarm Rate in Figure(10).
+            %      The figure number can be changed via an argument.
+            % PlotROC: draws the ROC curve in Figure(100).
+            %      The figure number can be changed via an argument.
+
+            VAL = VAL(:,:);
+            % Similarity or dissimilarity?
+            OB.flgSIM = true;
+            if nargin == 3
+                if varargin{1} == 'D'
+                    OB.flgSIM = false;
+                end
+            end
+
+            % One-class problem or not?
+            if size(VAL,1) >= 2
+                B = zeros(size(VAL));
+                Lu = unique(Label);
+                for I = 1:size(Lu,2)
+                    B(I, Label == Lu(I)) = 1;
+                end
+
+                if OB.flgSIM
+                    [v, ind] = max(VAL, [], 1);
+                else
+                    [v, ind] = min(VAL, [], 1);
+                end
+                OB.ER = 1 - mean(ind == Label);
+
+            else
+                B = zeros(size(Label));
+                B(Label ~= 0) = 1;
+            end
+            VAL = VAL(:);
+            B = B(:);
+
+            OB.nP = sum(B == 1);
+            OB.nN = sum(B == 0);
+
+            if OB.flgSIM
+                [OB.A, C] = sort(VAL, 'ascend');
+            else
+                [OB.A, C] = sort(VAL, 'descend');
+            end
+            D = B(C);
+
+            OB.FAR = 1 - cumsum(D == 0) / OB.nN;
+            OB.FRR = cumsum(D == 1) / OB.nP;
+
+            [val, ind] = min(abs(OB.FAR - OB.FRR));
+            OB.EER = (OB.FAR(ind) + OB.FRR(ind)) / 2;
+            OB.Thres = OB.A(ind);
+        end
+
+        function PlotEER(OB, varargin)
+            if nargin == 2
+                No = varargin{1};
+            else
+                No = 10;
+            end
+
+            figure(No)
+            clf;
+            hold on
+            plot(OB.A, OB.FRR, 'b');
+            plot(OB.A, OB.FAR, 'r');
+            title('FRR - FAR');
+            legend('False Reject Rate', 'False Alarm Rate');
+            xlabel('Threshold')
+            ylabel('Rate')
+            hold off
+        end
+
+        function PlotROC(OB, varargin)
+            if nargin == 2
+                No = varargin{1};
+                color = 'r';
+            elseif nargin == 3
+                No = varargin{1};
+                color = varargin{2};
+            else
+                No = 100;
+                color = 'r';
+            end
+
+            figure(No)
+            %clf;
+            hold on
+            plot(OB.FAR, 1-OB.FRR, color);
+            xlabel('False Positive Rate')
+            ylabel('True Positive Rate')
+            hold off
+        end
+    end
+end
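A short usage sketch of OrzEval as defined above (the score matrix and labels are invented toy values, not part of the commit):

VAL   = [0.9 0.8 0.3 0.2;   % similarity of each sample to class 1
         0.2 0.3 0.7 0.9];  % similarity of each sample to class 2
Label = [1 1 2 2];          % ground-truth class of each column

OB = OrzEval(VAL, Label);   % default: VAL is treated as similarities
disp(OB.ER)                 % classification error rate
disp(OB.EER)                % equal error rate
disp(OB.Thres)              % threshold at the equal-error point

OB.PlotEER();               % FRR/FAR curves in figure 10
OB.PlotROC(100, 'b');       % ROC curve in figure 100, drawn in blue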
displayModelResults.m

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+function displayModelResults(model_name, model_evaluation)
+% This function displays the model evaluation results.
+%
+% Parameters:
+%   - model_name: Name of the model being evaluated.
+%   - model_evaluation: Object of the ModelEvaluation class containing the evaluation results.
+
+% Assertions to check input types
+assert(ischar(model_name) || isstring(model_name), 'The model name should be a string.');
+assert(isa(model_evaluation, 'ModelEvaluation'), 'The model_evaluation should be an instance of the ModelEvaluation class.');
+
+% Display the model name
+fprintf('\nModel: %s\n', model_name);
+
+% Display the results
+disp('---------- Model Evaluation Results ----------');
+fprintf('Model Accuracy: %.2f%%\n', model_evaluation.accuracy * 100); % Display as percentage
+fprintf('Model Error Rate: %.2f%%\n', model_evaluation.error_rate * 100); % Display as percentage
+fprintf('Equal Error Rate (EER): %.2f%%\n', model_evaluation.equal_error_rate * 100); % Display as percentage
+fprintf('Classification Threshold: %.2f\n', model_evaluation.classification_threshold);
+disp('----------------------------------------------');
+end

src/functions/generateLabels.m

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+function labels = generateLabels(num_samples, varargin)
+% generateLabels Generates a vector of labels for dataset samples.
+%
+% This function creates a vector of integer labels for a dataset based on the
+% specified number of samples, classes, and sets. Labels come in contiguous
+% blocks of num_samples per class, with the class sequence repeated per set.
+%
+% Parameters:
+%   num_samples (integer): The number of samples in each class.
+%   varargin: Optional input arguments to specify the number of classes and sets.
+%     - If one additional argument is provided, it specifies the number of classes.
+%     - If two additional arguments are provided, the first specifies the number
+%       of classes, and the second specifies the number of sets.
+%
+% Returns:
+%   labels (1 x N integer array): The generated vector of labels, where N is the
+%   total number of samples across all classes and sets.
+%
+% Usage:
+%   labels = generateLabels(100);
+%   % Generates 100 labels, all assigned to a single class and set.
+%
+%   labels = generateLabels(100, 5);
+%   % Generates 500 labels for 5 classes, each class having 100 samples,
+%   % all within a single set.
+%
+%   labels = generateLabels(100, 5, 3);
+%   % Generates 1500 labels for 5 classes and 3 sets, each class having
+%   % 100 samples in each set.
+%
+% Errors:
+%   - An error is thrown if the number of input arguments is less than 1 or more than 3.
+%
+% Example:
+%   labels = generateLabels(10, 3, 2)
+%   % Returns a 1 x 60 vector: ten 1s, ten 2s, ten 3s, then the same three blocks again for the second set.
+
+% Validate the number of input arguments
+if nargin < 1
+    error('Not enough input arguments.');
+elseif nargin == 1
+    num_classes = 1; % Set default number of classes to 1
+    num_sets = 1;    % Set default number of sets to 1
+elseif nargin == 2
+    num_classes = varargin{1}; % Get number of classes from input
+    num_sets = 1;              % Set default number of sets to 1
+elseif nargin == 3
+    num_classes = varargin{1}; % Get number of classes from input
+    num_sets = varargin{2};    % Get number of sets from input
+else
+    error('Too many input arguments.');
+end

+% Build a num_samples-by-(num_classes*num_sets) matrix of class labels; its
+% column-major flattening gives num_samples copies of each label, per set.
+A = repmat(1:num_classes, [num_samples, num_sets]);
+
+% Convert the matrix of class labels into a single row vector
+labels = A(:)';
+end
+
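For the call generateLabels(size(testing_data, 3), num_classes) in examples/msm.m above, a quick sketch of the output shape (3 sets and 4 classes are invented sizes for illustration):

num_sets_demo = 3;                 % stand-in for size(testing_data, 3)
num_classes_demo = 4;              % stand-in for num_classes
labels_demo = generateLabels(num_sets_demo, num_classes_demo);
disp(labels_demo)
% 1 1 1 2 2 2 3 3 3 4 4 4
% i.e. num_sets consecutive copies of each class label, one block per class.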

src/functions/orzLabel.m

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+function L = orzLabel(nNum,varargin)
+% Generates a label row vector: nNum copies of each class label in turn,
+% with the class sequence repeated for every set (same logic as generateLabels).
+
+if nargin < 1
+    error('error');
+end
+if nargin == 1
+    nClass = 1;
+    nSet = 1;
+end
+
+if nargin == 2
+    nClass = varargin{1};
+    nSet = 1;
+end
+
+if nargin == 3
+    nClass = varargin{1};
+    nSet = varargin{2};
+end
+
+if nargin > 3
+    error('error');
+end
+
+A = repmat(1:nClass, [nNum, nSet]);
+L = A(:)';
+

untitled.m

Lines changed: 2 additions & 8 deletions
@@ -1,8 +1,2 @@
-% These are the similarity scores for each data point, representing likelihood of belonging to class 1
-evaluation_values = [0.9, 0.8, 0.2, 0.1, 0.9, 0.7];
-% 1 for class 1 and 2 for class 2
-labels = [1, 1, 2, 2, 1, 1];
-model_eval = ModelEvaluation(evaluation_values, labels);
-disp("Error Rate: " + model_eval.error_rate);
-disp("EER: " + model_eval.equal_error_rate);
-
+L = generateLabels(10, 3, 2);
+L
