fork download
  1. #include<bits/stdc++.h>
  // Modulus for every modular computation in this file. modinv() raises to the
  // power MOD - 2, which is a valid inverse only because this value is prime.
  #define MOD 998244353

  // Accessors for a p3 cell (see the typedef below): x(p) is the row and y(p)
  // the column. The argument is parenthesized so that arbitrary expressions
  // such as x(*ptr) or x(a ? b : c) expand correctly.
  #define x(p) ((p).second.first)
  #define y(p) ((p).second.second)

  // row^2 + col^2 of a cell, i.e. its squared distance from the origin.
  // Evaluated in 64-bit arithmetic so large coordinates cannot overflow int.
  #define sqeud(a) (1LL * x(a) * x(a) + 1LL * y(a) * y(a))

  using namespace std;

  typedef long long ll;

  // One grid cell stored as (value, (row, col)). Keeping the value first makes
  // the default pair ordering sort cells by value, then by position.
  typedef pair<ll, pair<int, int>> p3;
  11.  
  12. ll modpow(ll a, ll n) {
  13. if(n == 0 or a == 1) return 1;
  14. else if(n == 1) return a;
  15.  
  16. if(n&1) return (a * modpow(a, n-1)) % MOD;
  17.  
  18. ll temp = modpow(a, n / 2);
  19. return (temp*temp) % MOD;
  20. }
  21.  
  22. ll modinv(ll a) {
  23. return modpow(a, MOD-2);
  24. }
  25.  
  26.  
  27.  
  28. int main() {
  29. ios::sync_with_stdio(false);
  30.  
  31.  
  32. #ifdef DBG
  33. freopen("in", "r", stdin);
  34. #endif
  35.  
  36. int n, m, r, c;
  37. cin>>n>>m;
  38.  
  39. vector<p3> arr(n*m);
  40.  
  41. for(int i = 0; i < n; i++) {
  42. for(int j = 0; j < m; j++) {
  43. int ind = i*m + j;
  44. cin>>arr[ind].first;
  45. arr[ind].second = {i + 1, j + 1};
  46. }
  47. }
  48. cin>>r>>c;
  49.  
  50. sort(arr.begin(), arr.end());
  51.  
  52. vector<ll> sqsum(n*m, 0);
  53. vector<ll> xsum(n*m, 0);
  54. vector<ll> ysum(n*m, 0);
  55. vector<ll> vsum(n*m, 0);
  56. // int i;
  57.  
  58. ll v;
  59.  
  60. for(int i = 0; i < n*m; i++) {
  61. if(i == 0) {
  62. sqsum[i] = sqeud(arr[i]);
  63. xsum[i] = x(arr[i]);
  64. ysum[i] = y(arr[i]);
  65. } else {
  66. sqsum[i] = (sqsum[i-1] + sqeud(arr[i]))%MOD;
  67. xsum[i] = xsum[i-1] + x(arr[i]);
  68. ysum[i] = ysum[i-1] + y(arr[i]);
  69. }
  70.  
  71. if(x(arr[i]) == r and y(arr[i]) == c) {
  72. v = arr[i].first;
  73. }
  74. }
  75.  
  76. ll ans = 0ll;
  77.  
  78. for(int i = n*m - 1; i > 0; i--) {
  79. if(arr[i].first > v) continue;
  80. else if(arr[i].first == v and (x(arr[i]) != r or y(arr[i]) != c)) continue;
  81.  
  82. p3 value = {arr[i].first, {-1, -1}};
  83. int ind = lower_bound(arr.begin(), arr.end(), value) - arr.begin();
  84. if(ind == 0) {
  85. break;
  86. }
  87. ll count = ind;
  88. ind--;
  89. ll count_inv = modinv(count);
  90. ll val = (count*sqeud(arr[i]))%MOD;
  91.  
  92. val = (val + sqsum[ind])%MOD;
  93. val -= (2*x(arr[i])*xsum[ind]);
  94. val -= (2*y(arr[i])*ysum[ind]);
  95. while(val < 0) {
  96. val += MOD;
  97. }
  98.  
  99. val = (val * count_inv) % MOD;
  100. ans = (ans + val) % MOD;
  101. }
  102. cout<<ans;
  103. return 0;
  104. }
Success #stdin #stdout 0s 15240KB
stdin
2 3
1 5 7
2 3 1
1 2
stdout
332748123