#include <benchmark/benchmark.h>
#include <vector>
#include <random>
using namespace std;
// Modulus shared by the arithmetic helpers below.
const int MOD = 998'244'353;

// Modular addition. Only a single conditional subtraction of MOD is applied,
// so the result lands in [0, MOD) when both operands are already in [0, MOD).
int add (int a, int b) {
    int total = a + b;
    if (total >= MOD) {
        total -= MOD;
    }
    return total;
}

// Modular subtraction. Only a single conditional addition of MOD is applied,
// so the result lands in [0, MOD) when both operands are already in [0, MOD).
int sub (int a, int b) {
    int diff = a - b;
    if (diff < 0) {
        diff += MOD;
    }
    return diff;
}

// Modular multiplication. The product is formed in 64 bits before the
// remainder is taken, so it does not overflow for int operands.
int mul (int a, int b) {
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % MOD);
}
// Dense n x m matrix of ints stored row-major in the inherited std::vector<int>.
// The operator[] overloads declared here hide the vector's element-wise operator[]:
// indexing a Matrix yields a pointer to the start of a row, so mat[i][j] addresses
// row i, column j. No bounds checking is performed.
struct Matrix : std::vector<int> {
    // Number of rows (n) and columns (m).
    int n, m;

    // Creates an n x m matrix with every element value-initialized to 0.
    Matrix (int n, int m) :
        std::vector<int>(n * m), n(n), m(m) {}

    // Creates a matrix with `row` rows from a flat, row-major list of values.
    // The column count is init.size() / row (truncating); when `row` is 0 the
    // column count is set to 0 instead of dividing by zero.
    // Initializers are listed in the order they actually run (base class first,
    // then members in declaration order), and the column count is derived from
    // the `row` parameter so it does not depend on `n` being declared before `m`.
    Matrix (std::initializer_list<int> init, int row) :
        std::vector<int>(init.begin(), init.end()),
        n(row),
        m(row != 0 ? static_cast<int>(init.size() / row) : 0) {}

    // Pointer to the first element of row i (mutable access).
    int* operator[] (int i) { return data() + i * m; }

    // Pointer to the first element of row i (read-only access).
    const int* operator[] (int i) const { return data() + i * m; }
};
// Baseline benchmark: schoolbook triple-loop product of two square matrices,
// reducing modulo MOD after every multiply and every add.
// The matrix side length comes from the benchmark's first range argument.
static void matMulOriginal (benchmark::State &state) {
    const int dim = static_cast<int>(state.range(0));

    // Deterministic inputs: fixed seed, values drawn alternately for a and b.
    Matrix a(dim, dim), b(dim, dim);
    std::mt19937 rng(21);
    for (int row = 0; row < dim; ++row) {
        for (int col = 0; col < dim; ++col) {
            a[row][col] = rng() % MOD;
            b[row][col] = rng() % MOD;
        }
    }

    for (auto _ : state) {
        // The zero-filled result matrix is constructed inside the measured loop.
        Matrix c(dim, dim);
        for (int row = 0; row < dim; ++row) {
            for (int col = 0; col < dim; ++col) {
                for (int k = 0; k < dim; ++k) {
                    c[row][col] = add(c[row][col], mul(a[row][k], b[k][col]));
                }
            }
        }
        // Google Benchmark barriers intended to keep the result from being optimized away.
        benchmark::DoNotOptimize(c.data());
        benchmark::ClobberMemory();
    }
}
// Register the benchmark for side lengths 2 through 1024, multiplying by 2 each step.
BENCHMARK(matMulOriginal)
    ->RangeMultiplier(2)
    ->Range(1 << 1, 1 << 10);
// Benchmark: triple-loop product of two square matrices modulo MOD, where b is
// first copied into a transposed matrix so the innermost loop reads both
// operands along a row. The transpose is rebuilt inside the measured loop.
// The matrix side length comes from the benchmark's first range argument.
static void matMulTranspose (benchmark::State &state) {
    const int dim = static_cast<int>(state.range(0));

    // Deterministic inputs: fixed seed, values drawn alternately for a and b.
    Matrix a(dim, dim), b(dim, dim);
    std::mt19937 rng(21);
    for (int row = 0; row < dim; ++row) {
        for (int col = 0; col < dim; ++col) {
            a[row][col] = rng() % MOD;
            b[row][col] = rng() % MOD;
        }
    }

    for (auto _ : state) {
        Matrix bT(dim, dim), c(dim, dim);

        // bT[row][col] mirrors b[col][row].
        for (int row = 0; row < dim; ++row) {
            for (int col = 0; col < dim; ++col) {
                bT[row][col] = b[col][row];
            }
        }

        // c[row][col] accumulates the dot product of row `row` of a and row `col` of bT.
        for (int row = 0; row < dim; ++row) {
            for (int col = 0; col < dim; ++col) {
                for (int k = 0; k < dim; ++k) {
                    c[row][col] = add(c[row][col], mul(a[row][k], bT[col][k]));
                }
            }
        }

        // Google Benchmark barriers intended to keep the result from being optimized away.
        benchmark::DoNotOptimize(c.data());
        benchmark::ClobberMemory();
    }
}
// Register the benchmark for side lengths 2 through 1024, multiplying by 2 each step.
BENCHMARK(matMulTranspose)
    ->RangeMultiplier(2)
    ->Range(1 << 1, 1 << 10);
// Edge length of the square tiles processed by matMulTiling.
// NOTE(review): matMulTiling adds up to TILESIZE unreduced products (each below MOD * MOD)
// into an unsigned long long before taking % MOD; with the current MOD that appears to
// leave room for at most 18 terms — re-check the bound before raising this value.
const int TILESIZE = 16;
// Global scratch tile written and read by matMulTiling: holds the current tile of b,
// stored transposed (indexed [column][row] relative to b).
int bCached[TILESIZE][TILESIZE];
// Benchmark: tiled (blocked) product of two square matrices modulo MOD.
//
// The iteration space is split into TILESIZE x TILESIZE tiles. For every
// (iTile, jTile, kTile) triple the matching tile of b is copied, transposed,
// into the global scratch buffer bCached, so the innermost loop reads a row of
// a and a row of bCached. Products are accumulated unreduced in an
// unsigned long long and reduced modulo MOD once per tile rather than once per
// multiply. The matrix side length comes from the benchmark's first range argument.
static void matMulTiling (benchmark::State &state) {
    // The deferred reduction sums up to TILESIZE products, each at most
    // (MOD - 1)^2, on top of a starting value below MOD. Refuse to compile if
    // that sum could wrap around the unsigned long long accumulator.
    static_assert(
        1ULL * TILESIZE <= (~0ULL - (MOD - 1)) / (1ULL * (MOD - 1) * (MOD - 1)),
        "TILESIZE is too large: the unreduced accumulator in matMulTiling could overflow");

    const int n = static_cast<int>(state.range(0));

    // Deterministic inputs: fixed seed, values drawn alternately for a and b.
    Matrix a(n, n), b(n, n);
    std::mt19937 rng(21);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            a[i][j] = rng() % MOD;
            b[i][j] = rng() % MOD;
        }
    }

    for (auto _ : state) {
        // The zero-filled result matrix is constructed inside the measured loop.
        Matrix c(a.n, b.m);
        for (int iTile = 0; iTile < a.n; iTile += TILESIZE) {
            // Tiles on the matrix edge may be smaller than TILESIZE.
            const int iSize = std::min(TILESIZE, a.n - iTile);
            for (int jTile = 0; jTile < b.m; jTile += TILESIZE) {
                const int jSize = std::min(TILESIZE, b.m - jTile);
                for (int kTile = 0; kTile < a.m; kTile += TILESIZE) {
                    const int kSize = std::min(TILESIZE, a.m - kTile);

                    // Copy the current tile of b into bCached, transposing on the way.
                    for (int k = 0; k < kSize; k++) {
                        for (int j = 0; j < jSize; j++) {
                            bCached[j][k] = b[k + kTile][j + jTile];
                        }
                    }

                    // Multiply the current tile of a with the cached tile of b.
                    for (int i = 0; i < iSize; i++) {
                        for (int j = 0; j < jSize; j++) {
                            // Partial dot product of a's row segment and bCached's row j,
                            // continuing from the value already stored in c.
                            unsigned long long hold = c[i + iTile][j + jTile];
                            for (int k = 0; k < kSize; k++) {
                                hold += 1ULL * a[i + iTile][k + kTile] * bCached[j][k];
                            }
                            c[i + iTile][j + jTile] = static_cast<int>(hold % MOD);
                        }
                    }
                }
            }
        }
        // Google Benchmark barriers intended to keep the result from being optimized away.
        benchmark::DoNotOptimize(c.data());
        benchmark::ClobberMemory();
    }
}
// Register the benchmark for side lengths 2 through 1024, multiplying by 2 each step.
BENCHMARK(matMulTiling)
    ->RangeMultiplier(2)
    ->Range(1 << 1, 1 << 10);
// Google Benchmark macro that provides the program's main(), which runs the registered benchmarks.
BENCHMARK_MAIN();