avutil/tx_template: add cache-friendly matrix transpose

This commit is contained in:
Paul B Mahol 2024-09-25 20:46:05 +02:00
parent 5dc71d7193
commit 6ef3eea3c8

View File

@@ -1190,12 +1190,39 @@ static void TX_NAME(ff_tx_fft_naive_small)(AVTXContext *s, void *_dst, void *_sr
}
}
/*
 * NOTE(review): this span is a diff view whose +/- markers were lost, so the
 * removed and the added version of transpose_matrix() are interleaved here.
 * Presumably the one-line signature and the plain double loop directly after
 * the opening brace are the removed code, and the three-line signature plus
 * everything from "const int block" onwards is the added code -- confirm
 * against the actual commit.
 */
static void transpose_matrix(TXComplex *out, const TXComplex *in, int m, int n)
static void transpose_matrix(TXComplex *out, const TXComplex *in,
int row, int col, int n, int m,
const int sn, const int sm)
{
/* Plain double loop: stores out[i * n + j] = in[j * m + i] for i < m, j < n. */
for (int i = 0; i < m; i++, in++) {
for (int j = 0, k = 0; j < n; j++, k+=m)
out[j] = in[k];
out += n;
/*
 * Blocked variant: stores out[j * sm + i] = in[i * sn + j] for every i in
 * [row, row + m) and j in [col, col + n). sn and sm are the strides, in
 * elements, used to step through in and out respectively.
 */
/* Number of TXComplex elements in 64 bytes (presumably one cache line). */
const int block = 64 / sizeof(*out);
/* Region too large in either dimension: halve the larger one and recurse. */
if (m > block || n > block) {
if (n >= m) {
/* Split the column range [col, col + n); also taken when n == m. */
const int half_n = n / 2;
transpose_matrix(out, in, row, col, half_n, m, sn, sm);
/* Remaining columns; n - half_n keeps the odd element when n is odd. */
transpose_matrix(out, in, row, col+half_n, n-half_n, m, sn, sm);
} else {
/* Split the row range [row, row + m). */
const int half_m = m / 2;
transpose_matrix(out, in, row, col, n, half_m, sn, sm);
/* Remaining rows; m - half_m keeps the odd element when m is odd. */
transpose_matrix(out, in, row+half_m, col, n, m-half_m, sn, sm);
}
} else {
/* Small region: copy it element by element. */
/* src tracks the start of the current input row. */
const TXComplex *src = in + row * sn;
const int rlimit = row + m;
const int climit = col + n;
for (int i = row; i < rlimit; i++) {
/* Restart at output row col for every input row. */
TXComplex *dst = out + col * sm;
for (int j = col; j < climit; j++) {
/* Input element (i, j) lands at output element (j, i). */
dst[i] = src[j];
/* Step to the next output row. */
dst += sm;
}
/* Step to the next input row. */
src += sn;
}
}
}
@@ -1217,7 +1244,7 @@ static void TX_NAME(ff_tx_fft_bailey)(AVTXContext *s, void *_dst, void *_src,
stride /= sizeof(*dst);
transpose_matrix(tmp2, src, n, m);
transpose_matrix(tmp2, src, 0, 0, n, m, n, m);
for (int i = 0; i < n; i++) {
fn0(sub0, tmp, tmp2, sizeof(TXComplex));
@@ -1233,7 +1260,7 @@ static void TX_NAME(ff_tx_fft_bailey)(AVTXContext *s, void *_dst, void *_src,
CMUL3(tmp[i], x, exp[i]);
}
transpose_matrix(tmp2, tmp, m, n);
transpose_matrix(tmp2, tmp, 0, 0, m, n, m, n);
for (int i = 0; i < m; i++) {
fn1(sub1, tmp, tmp2, sizeof(TXComplex));
@@ -1243,7 +1270,7 @@ static void TX_NAME(ff_tx_fft_bailey)(AVTXContext *s, void *_dst, void *_src,
tmp2 = ((TXComplex *)s->exp) + len;
tmp = tmp2 + len;
transpose_matrix(tmp2, tmp, n, m);
transpose_matrix(tmp2, tmp, 0, 0, n, m, n, m);
if (stride == 1) {
memcpy(dst, tmp2, len * sizeof(TXComplex));