Skip to content

Commit 5e9bdcc

Browse files
Santos SafraoSantos Safrao
authored andcommitted
compute canonical angles
1 parent 321ab12 commit 5e9bdcc

File tree

6 files changed

+316
-125
lines changed

6 files changed

+316
-125
lines changed

computeSubspacesSimilarities.m

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
function similarities = computeSubspacesSimilarities(X, Y, varargin)
2+
% computeSubspacesSimilarities: Compute similarities between sets of subspaces.
3+
% This function calculates the similarities between two sets of subspaces represented
4+
% by multi-dimensional arrays X and Y. The similarity is defined based on the squared
5+
% singular values of the cross-covariance matrix of the subspaces.
6+
%
7+
% Parameters:
8+
% X: A num_dim_x x num_sub_dim_x x num_sets_x 3D matrix representing the first set of subspaces.
9+
% - num_dim_x: dimension of the original vector space.
10+
% - num_sub_dim_x: dimension of each subspace in X.
11+
% - num_sets_x: number of subspaces in X.
12+
%
13+
% Y: A num_dim_y x num_sub_dim_y x num_sets_y 3D matrix representing the second set of subspaces.
14+
% - num_dim_y: dimension of the original vector space.
15+
% - num_sub_dim_y: dimension of each subspace in Y.
16+
% - num_sets_y: number of subspaces in Y.
17+
%
18+
% varargin: A string flag that determines the method of computation.
19+
% - If not specified or empty, similarities are based on the average of cos^2 theta,
20+
% computed efficiently using the trace of the cross-covariance matrix.
21+
% - If set to 'F', the function uses Singular Value Decomposition (SVD) to compute all
22+
% canonical angles and outputs both the average and the cos^2 of the first theta.
23+
%
24+
% Return:
25+
% similarities: A 4D matrix containing the computed similarities.
26+
% - size: num_sets_x x num_sets_y x num_sub_dim_x x num_sub_dim_y.
27+
% - Each element (i, j, k, l) represents the similarity between the k-th subspace of the
28+
% i-th set in X and the l-th subspace of the j-th set in Y.
29+
% - if vargin is set to 'F', similarities becomes a 3D matrix of size
30+
% num_sets_x x num_sets_y x min(num_sub_dim_x, num_sub_dim_y).
31+
%
32+
% Example Usage:
33+
% sim = computeSubspacesSimilarities(X, Y);
34+
% % Computes the subspace similarities between X and Y using the default method.
35+
%
36+
% sim = computeSubspacesSimilarities(X, Y, 'F');
37+
% % Computes the subspace similarities using SVD for all canonical angles.
38+
%
39+
% Notes:
40+
% - The function normalizes the similarity by the number of singular values used in the computation.
41+
% - For high-dimensional data, using the default method (without 'F') is computationally more efficient.
42+
%
43+
% Last Update: 2023/10 by Santos Enoque
44+
% Computer Vision Laboratory, University of Tsukuba
45+
% http://www.cvlab.cs.tsukuba.ac.jp/
46+
47+
x_size = size(X);
48+
x_size = length(x_size);
49+
y_size = size(Y);
50+
y_size = length(y_size);
51+
% Check if the size of any of the subspaces is less than 3 or greater than 4
52+
if any(x_size < 3 | x_size > 4) || any(y_size < 3 | y_size > 4)
53+
error('The size of each subspace must be greater than or equal to 3 and less than or equal to 4.');
54+
end
55+
56+
use_svd = false;
57+
58+
% Check if an additional input argument is provided
59+
if nargin == 3
60+
% If the additional input argument is 'F', set the flag for using SVD to true
61+
if varargin{1} == 'F'
62+
use_svd = true;
63+
end
64+
end
65+
% Ensure X and Y are 3D matrices, even if they are 4D initially
66+
X = X(:,:,:);
67+
Y = Y(:,:,:);
68+
69+
% Get dimensions information of input subspaces X and Y
70+
[~, num_dim_x, num_sets_x] = size(X);
71+
[~, num_dim_y, num_sets_y] = size(Y);
72+
73+
% If SVD is not used, compute similarities using cross-covariance
74+
if ~use_svd
75+
% Compute the squared cross-covariance matrix and reshape it to a 4D matrix
76+
cross_cov_squared = reshape((X(:,:)' * Y(:, :)).^2, num_dim_x, num_sets_x, num_dim_y, num_sets_y);
77+
78+
% Cumulatively sum the squared cross-covariance matrix along dimensions
79+
similarities = cumsum(cumsum(cross_cov_squared, 1), 3);
80+
81+
% Normalize the similarities by the minimum dimension between X and Y subspaces
82+
for i = 1:num_dim_x
83+
for j = 1:num_dim_y
84+
similarities(i, :, j, :) = similarities(i, :, j, :) / min([i, j]);
85+
end
86+
end
87+
88+
% Rearrange dimensions of the similarities matrix for consistent output format
89+
similarities = permute(similarities, [2, 4, 1, 3]);
90+
91+
% If SVD is used, compute similarities using singular values
92+
else
93+
% Compute the cross-covariance matrix and reshape it to a 4D matrix
94+
cross_cov = reshape((X(:,:)' * Y(:, :)), num_dim_x, num_sets_x, num_dim_y, num_sets_y);
95+
96+
% Rearrange dimensions of the cross-covariance matrix for consistent computation
97+
cross_cov = permute(cross_cov, [1, 3, 2, 4]);
98+
99+
% Initialize the similarities matrix with zeros
100+
num_dim = min([num_dim_x, num_dim_y]);
101+
similarities = zeros(num_sets_x, num_sets_y, num_dim);
102+
103+
% Compute cumulative sum of squared singular values for each pair of subspaces
104+
for i = 1:num_sets_x
105+
for j = 1:num_sets_y
106+
similarities(i, j, :) = cumsum(svd(cross_cov(:, :, i, j)).^2, 1);
107+
end
108+
end
109+
110+
% Normalize the similarities by the dimensionality
111+
for i = 1:num_dim
112+
similarities(:, :, i) = similarities(:, :, i) / i;
113+
end
114+
end
115+
end

run_unit_tests.m

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,50 @@
1-
clc
2-
functions_to_test = {'computePCA', 'computeBasisVectors'};
3-
num_of_functions = numel(functions_to_test);
1+
clc;
42

5-
for i = 1:num_of_functions
6-
% Capitalize the first letter of the function name
7-
func_name_capitalized = functions_to_test{i};
8-
func_name_capitalized(1) = upper(func_name_capitalized(1));
3+
% List of functions to test
4+
functions_to_test = {'ComputePCA',...
5+
'ComputeBasisVectors',...
6+
'ComputeSubspacesSimilarities'};
7+
8+
functions_to_test = functions_to_test(1, end);
9+
10+
% Run tests for each function
11+
for func_name = functions_to_test
12+
runTestsForFunction(char(func_name));
13+
end
14+
15+
function runTestsForFunction(func_name)
16+
try
17+
% Generate the test function name based on the function name
18+
test_function_name = ['test', func_name];
19+
20+
% Trim leading and trailing whitespace
21+
test_function_name = strtrim(test_function_name);
22+
23+
% Check if the test function exists
24+
if exist(test_function_name, 'file') ~= 2
25+
error('Test function "%s" does not exist.', test_function_name);
26+
end
27+
28+
% Run the test function
29+
results = runtests(test_function_name);
30+
31+
% Display results
32+
displayTestResults(results, func_name);
33+
catch ME
34+
fprintf('Error running tests for "%s": %s\n', func_name, ME.message);
35+
end
36+
end
37+
38+
function displayTestResults(results, func_name)
39+
fprintf('\nResults for "%s":\n', func_name);
40+
fprintf('----------------------------------------\n');
941

10-
% Create the test function name
11-
test_function_name = ['test', func_name_capitalized];
42+
if all([results.Passed])
43+
fprintf('All %d tests passed.\n', numel(results));
44+
else
45+
fprintf('%d out of %d tests failed.\n', sum([results.Failed]), numel(results));
46+
disp(results([results.Failed]));
47+
end
1248

13-
% Run the test function
14-
results = runtests(test_function_name);
15-
disp(results)
49+
fprintf('----------------------------------------\n');
1650
end
17-

testComputeBasisVectors.m

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,58 @@
1-
function tests = testComputeBasisVectors
2-
tests = functiontests(localfunctions);
3-
end
4-
5-
function testPositiveIntegerNumSubDim(testCase)
6-
X = rand(10, 20, 5);
7-
num_sub_dim = 5;
8-
basis_vectors = computeBasisVectors(X, num_sub_dim);
9-
expected_size = [10, 5, 5];
10-
verifySize(testCase, basis_vectors, expected_size);
11-
end
1+
classdef testComputeBasisVectors < matlab.unittest.TestCase
2+
3+
methods (Test)
4+
function testPositiveIntegerNumSubDim(testCase)
5+
X = rand(10, 20, 5);
6+
num_sub_dim = 5;
7+
basis_vectors = computeBasisVectors(X, num_sub_dim);
8+
expected_size = [10, 5, 5];
9+
verifySize(testCase, basis_vectors, expected_size);
10+
end
1211

13-
function testRatioNumSubDim(testCase)
14-
X = rand(10, 20, 5);
15-
num_sub_dim = 0.95;
16-
basis_vectors = computeBasisVectors(X, num_sub_dim);
17-
expected_size = [10, 20, 5];
18-
actual_size = size(basis_vectors);
19-
verifyLessThanOrEqual(testCase, actual_size(2), expected_size(2));
20-
end
12+
function testRatioNumSubDim(testCase)
13+
X = rand(10, 20, 5);
14+
num_sub_dim = 0.95;
15+
basis_vectors = computeBasisVectors(X, num_sub_dim);
16+
expected_size = [10, 20, 5];
17+
actual_size = size(basis_vectors);
18+
verifyLessThanOrEqual(testCase, actual_size(2), expected_size(2));
19+
end
2120

22-
function testInvalidNumSubDim(testCase)
23-
X = rand(10, 20, 5);
24-
num_sub_dim = -5;
25-
verifyError(testCase, @() computeBasisVectors(X, num_sub_dim), '');
26-
end
21+
function testInvalidNumSubDim(testCase)
22+
X = rand(10, 20, 5);
23+
num_sub_dim = -5;
24+
verifyError(testCase, @() computeBasisVectors(X, num_sub_dim), '');
25+
end
2726

28-
function testInvalidRatioNumSubDim(testCase)
29-
X = rand(10, 20, 5);
30-
num_sub_dim = 1.5;
31-
verifyError(testCase, @() computeBasisVectors(X, num_sub_dim), '');
32-
end
27+
function testInvalidRatioNumSubDim(testCase)
28+
X = rand(10, 20, 5);
29+
num_sub_dim = 1.5;
30+
verifyError(testCase, @() computeBasisVectors(X, num_sub_dim), '');
31+
end
3332

34-
function test4DInput(testCase)
35-
X = rand(10, 20, 3, 5);
36-
num_sub_dim = 5;
37-
basis_vectors = computeBasisVectors(X, num_sub_dim);
38-
expected_size = [10, 5, 3, 5];
39-
verifySize(testCase, basis_vectors, expected_size);
40-
end
33+
function test4DInput(testCase)
34+
X = rand(10, 20, 3, 5);
35+
num_sub_dim = 5;
36+
basis_vectors = computeBasisVectors(X, num_sub_dim);
37+
expected_size = [10, 5, 3, 5];
38+
verifySize(testCase, basis_vectors, expected_size);
39+
end
4140

42-
function testOrthogonalityOfBasisVectors(testCase)
43-
X = rand(10, 20, 5);
44-
num_sub_dim = 5;
45-
basis_vectors = computeBasisVectors(X, num_sub_dim);
46-
47-
num_sets = size(basis_vectors, 3);
48-
49-
for i = 1:num_sets
50-
basis_set = basis_vectors(:,:,i);
51-
dot_product_matrix = basis_set' * basis_set;
52-
identity_matrix = eye(size(dot_product_matrix));
53-
% Check if the dot product matrix is approximately an identity matrix
54-
verifyEqual(testCase, dot_product_matrix, identity_matrix, 'AbsTol', 1e-10);
41+
function testOrthogonalityOfBasisVectors(testCase)
42+
X = rand(10, 20, 5);
43+
num_sub_dim = 5;
44+
basis_vectors = computeBasisVectors(X, num_sub_dim);
45+
46+
num_sets = size(basis_vectors, 3);
47+
48+
for i = 1:num_sets
49+
basis_set = basis_vectors(:,:,i);
50+
dot_product_matrix = basis_set' * basis_set;
51+
identity_matrix = eye(size(dot_product_matrix));
52+
% Check if the dot product matrix is approximately an identity matrix
53+
verifyEqual(testCase, dot_product_matrix, identity_matrix, 'AbsTol', 1e-10);
54+
end
55+
end
5556
end
57+
5658
end
57-

0 commit comments

Comments
 (0)