Skip to content

Commit 9e3eba0

Browse files
committed
covariance: addition of complex
1 parent 340309e commit 9e3eba0

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

src/stdlib_experimental_stats_cov.fypp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ contains
2020
, merge(size(x, 1), size(x, 2), mask = 1<dim))
2121

2222
integer :: i
23-
real(${k1}$) :: mean_(merge(size(x, 1), size(x, 2), mask = 1<dim))
24-
real(${k1}$) :: center(size(x, 1),size(x, 2))
23+
${t1}$ :: mean_(merge(size(x, 1), size(x, 2), mask = 1<dim))
24+
${t1}$ :: center(size(x, 1),size(x, 2))
2525

2626
if (.not.optval(mask, .true.)) then
2727
res = ieee_value(1._${k1}$, ieee_quiet_nan)
@@ -37,7 +37,7 @@ contains
3737
#:if t1[0] == 'r'
3838
res = matmul( transpose(center), center)
3939
#:else
40-
error stop
40+
res = matmul( transpose(conjg(center)), center)
4141
#:endif
4242
case(2)
4343
do i = 1, size(x, 2)
@@ -46,7 +46,7 @@ contains
4646
#:if t1[0] == 'r'
4747
res = matmul( center, transpose(center))
4848
#:else
49-
error stop
49+
res = matmul( center, transpose(conjg(center)))
5050
#:endif
5151
case default
5252
call error_stop("ERROR (mean): wrong dimension")

src/tests/stats/test_cov.f90

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,20 @@ program test_moment
1313
2._dp, 4._dp, 6._dp, 8._dp,&
1414
9._dp, 10._dp, 11._dp, 12._dp], [4, 3])
1515

16+
complex(dp) :: ds(2,3) = reshape([ cmplx(1._dp, 0._dp),&
17+
cmplx(0._dp, 2._dp),&
18+
cmplx(3._dp, 0._dp),&
19+
cmplx(0._dp, 4._dp),&
20+
cmplx(5._dp, 0._dp),&
21+
cmplx(0._dp, 6._dp)], [2, 3])
22+
1623

1724
call test_dp(d)
25+
1826
call test_int32(int(d, int32))
1927

28+
call test_cdp(ds)
29+
2030
contains
2131

2232
subroutine test_dp(x2)
@@ -102,4 +112,41 @@ subroutine test_int32(x2)
102112
, 'int32 check 6')
103113

104114
end subroutine test_int32
115+
116+
subroutine test_cdp(x2)
117+
complex(dp), intent(in) :: x2(:, :)
118+
119+
call check( any(ieee_is_nan(cov(x2, 1, mask = .false.)))&
120+
, 'cdp check 1')
121+
call check( any(ieee_is_nan(cov(x2, 2, mask = .false.)))&
122+
, 'cdp check 2')
123+
124+
125+
call check( all( abs( cov(x2, 1) - reshape([&
126+
2.5_dp, 5.5_dp, 8.5_dp, 5.5_dp, 12.5_dp&
127+
, 19.5_dp, 8.5_dp, 19.5_dp, 30.5_dp]&
128+
,[ size(x2, 2), size(x2, 2)])&
129+
) < dptol)&
130+
, 'cdp check 3')
131+
call check( all( abs( cov(x2, 2) - reshape([&
132+
4._dp, 0._dp, 0._dp, 4._dp]&
133+
,[ size(x2, 1), size(x2, 1)])&
134+
) < dptol)&
135+
, 'cdp check 4')
136+
137+
call check( all( abs( cov(x2, 1, corrected=.false.) - reshape([&
138+
2.5_dp, 5.5_dp, 8.5_dp, 5.5_dp, 12.5_dp&
139+
, 19.5_dp, 8.5_dp, 19.5_dp, 30.5_dp]&
140+
*(size(x2, 1)-1._dp)/size(x2, 1)&
141+
,[ size(x2, 2), size(x2, 2)])&
142+
) < dptol)&
143+
, 'cdp check 5')
144+
call check( all( abs( cov(x2, 2, corrected=.false.) - reshape([&
145+
4._dp, 0._dp, 0._dp, 4._dp]&
146+
*(size(x2, 2)-1._dp)/size(x2, 2)&
147+
,[ size(x2, 1), size(x2, 1)])&
148+
) < dptol)&
149+
, 'cdp check 6')
150+
151+
end subroutine test_cdp
105152
end program test_moment

0 commit comments

Comments
 (0)