|
5 | 5 |
|
6 | 6 | submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
|
7 | 7 | use stdlib_linalg_blas, only: gemm
|
| 8 | +use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_VALUE_ERROR |
8 | 9 | use stdlib_constants
|
9 | 10 | implicit none
|
10 | 11 |
|
| 12 | +character(len=*), parameter :: this = "stdlib_matmul" |
| 13 | + |
11 | 14 | contains
|
12 | 15 |
|
13 | 16 | ! Algorithm for the optimal parenthesization of matrices
|
@@ -122,41 +125,76 @@ contains
|
122 | 125 |
|
123 | 126 | end function matmul_chain_mult_${s}$_4
|
124 | 127 |
|
125 |
| -pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r) |
| 128 | +pure module subroutine stdlib_matmul_sub_${s}$ (res, m1, m2, m3, m4, m5, err) |
| 129 | +${t}$, intent(out), allocatable :: res(:,:) |
126 | 130 | ${t}$, intent(in) :: m1(:,:), m2(:,:)
|
127 | 131 | ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
|
128 |
| -${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:) |
| 132 | +type(linalg_state_type), intent(out), optional :: err |
| 133 | +${t}$, allocatable :: temp(:,:), temp1(:,:) |
129 | 134 | integer :: p(6), num_present, m, n, k
|
130 | 135 | integer, allocatable :: s(:,:)
|
131 | 136 |
|
| 137 | +type(linalg_state_type) :: err0 |
| 138 | + |
132 | 139 | p(1) = size(m1, 1)
|
133 | 140 | p(2) = size(m2, 1)
|
134 | 141 | p(3) = size(m2, 2)
|
135 | 142 |
|
| 143 | +if (size(m1, 2) /= p(2)) then |
| 144 | +err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m1, m2 not of compatible sizes') |
| 145 | +call linalg_error_handling(err0, err) |
| 146 | +allocate(res(0, 0)) |
| 147 | +return |
| 148 | +end if |
| 149 | + |
136 | 150 | num_present = 2
|
137 | 151 | if (present(m3)) then
|
| 152 | + |
| 153 | +if (size(m3, 1) /= p(3)) then |
| 154 | +err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m2, m3 not of compatible sizes') |
| 155 | +call linalg_error_handling(err0, err) |
| 156 | +allocate(res(0, 0)) |
| 157 | +return |
| 158 | +end if |
| 159 | + |
138 | 160 | p(3) = size(m3, 1)
|
139 | 161 | p(4) = size(m3, 2)
|
140 | 162 | num_present = num_present + 1
|
141 | 163 | end if
|
142 | 164 | if (present(m4)) then
|
| 165 | + |
| 166 | +if (size(m4, 1) /= p(4)) then |
| 167 | +err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m3, m4 not of compatible sizes') |
| 168 | +call linalg_error_handling(err0, err) |
| 169 | +allocate(res(0, 0)) |
| 170 | +return |
| 171 | +end if |
| 172 | + |
143 | 173 | p(4) = size(m4, 1)
|
144 | 174 | p(5) = size(m4, 2)
|
145 | 175 | num_present = num_present + 1
|
146 | 176 | end if
|
147 | 177 | if (present(m5)) then
|
| 178 | + |
| 179 | +if (size(m5, 1) /= p(5)) then |
| 180 | +err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m4, m5 not of compatible sizes') |
| 181 | +call linalg_error_handling(err0, err) |
| 182 | +allocate(res(0, 0)) |
| 183 | +return |
| 184 | +end if |
| 185 | + |
148 | 186 | p(5) = size(m5, 1)
|
149 | 187 | p(6) = size(m5, 2)
|
150 | 188 | num_present = num_present + 1
|
151 | 189 | end if
|
152 | 190 |
|
153 |
| -allocate(r(p(1), p(num_present + 1))) |
| 191 | +allocate(res(p(1), p(num_present + 1))) |
154 | 192 |
|
155 | 193 | if (num_present == 2) then
|
156 | 194 | m = p(1)
|
157 | 195 | n = p(3)
|
158 | 196 | k = p(2)
|
159 |
| -call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r, m) |
| 197 | +call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, res, m) |
160 | 198 | return
|
161 | 199 | end if
|
162 | 200 |
|
@@ -166,10 +204,10 @@ contains
|
166 | 204 | s = matmul_chain_order(p(1: num_present + 1))
|
167 | 205 |
|
168 | 206 | if (num_present == 3) then
|
169 |
| -r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4)) |
| 207 | +res = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4)) |
170 | 208 | return
|
171 | 209 | else if (num_present == 4) then
|
172 |
| -r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5)) |
| 210 | +res = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5)) |
173 | 211 | return
|
174 | 212 | end if
|
175 | 213 |
|
@@ -182,7 +220,7 @@ contains
|
182 | 220 | m = p(1)
|
183 | 221 | n = p(6)
|
184 | 222 | k = p(2)
|
185 |
| -call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m) |
| 223 | +call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, res, m) |
186 | 224 | case (2)
|
187 | 225 | ! (m1*m2)*(m3*m4*m5)
|
188 | 226 | m = p(1)
|
@@ -195,7 +233,7 @@ contains
|
195 | 233 |
|
196 | 234 | k = n
|
197 | 235 | n = p(6)
|
198 |
| -call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) |
| 236 | +call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res, m) |
199 | 237 | case (3)
|
200 | 238 | ! (m1*m2*m3)*(m4*m5)
|
201 | 239 | temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3, s, p)
|
@@ -208,18 +246,35 @@ contains
|
208 | 246 |
|
209 | 247 | k = m
|
210 | 248 | m = p(1)
|
211 |
| -call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) |
| 249 | +call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res, m) |
212 | 250 | case (4)
|
213 | 251 | ! (m1*m2*m3*m4)*m5
|
214 | 252 | temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p)
|
215 | 253 | m = p(1)
|
216 | 254 | n = p(6)
|
217 | 255 | k = p(5)
|
218 |
| -call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m) |
| 256 | +call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, res, m) |
219 | 257 | case default
|
220 |
| -error stop "stdlib_matmul: error: unexpected s(i,j)" |
| 258 | +error stop "stdlib_matmul: internal error: unexpected s(i,j)" |
221 | 259 | end select
|
222 | 260 |
|
| 261 | +end subroutine stdlib_matmul_sub_${s}$ |
| 262 | + |
| 263 | +pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r) |
| 264 | +${t}$, intent(in) :: m1(:,:), m2(:,:) |
| 265 | +${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) |
| 266 | +${t}$, allocatable :: r(:,:) |
| 267 | + |
| 268 | +call stdlib_matmul_sub(r, m1, m2, m3, m4, m5) |
| 269 | +end function stdlib_matmul_pure_${s}$ |
| 270 | + |
| 271 | +module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r) |
| 272 | +${t}$, intent(in) :: m1(:,:), m2(:,:) |
| 273 | +${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) |
| 274 | +type(linalg_state_type), intent(out) :: err |
| 275 | +${t}$, allocatable :: r(:,:) |
| 276 | + |
| 277 | +call stdlib_matmul_sub(r, m1, m2, m3, m4, m5, err=err) |
223 | 278 | end function stdlib_matmul_${s}$
|
224 | 279 |
|
225 | 280 | #:endfor
|
|
0 commit comments