fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  namespace {

  // Adjacency list with 1-based vertices: adj[u] holds (neighbor, edge weight) pairs.
  using Graph = std::vector<std::vector<std::pair<int, int>>>;

  // For every vertex v in [1, n], returns the XOR of the edge weights on the path
  // from root 1 to v (index 0 is unused and stays 0). Traversal uses an explicit
  // stack instead of recursion so a deep, path-shaped tree cannot overflow the
  // call stack. Vertices not reachable from 1 keep the value 0.
  // Precondition: adj.size() >= n + 1 and every stored neighbor is in [1, n].
  std::vector<std::uint32_t> prefix_xor_from_root(const Graph& adj, int n) {
      std::vector<std::uint32_t> px(static_cast<std::size_t>(n) + 1, 0);
      if (n < 1) return px;  // no root to start from

      std::vector<char> seen(px.size(), 0);
      std::vector<int> stack;
      stack.reserve(static_cast<std::size_t>(n));
      seen[1] = 1;
      stack.push_back(1);
      while (!stack.empty()) {
          const int u = stack.back();
          stack.pop_back();
          for (const auto& [v, w] : adj[u]) {
              if (seen[v]) continue;
              seen[v] = 1;
              // Weights are XOR-ed as raw 32-bit patterns, so negative inputs are still well-defined.
              px[v] = px[u] ^ static_cast<std::uint32_t>(w);
              stack.push_back(v);
          }
      }
      return px;
  }

  // Returns how many unordered pairs (i < j) satisfy values[i] == values[j].
  // Takes the vector by value because it sorts it; a run of k equal values
  // contributes k*(k-1)/2 pairs.
  long long count_equal_pairs(std::vector<std::uint32_t> values) {
      std::sort(values.begin(), values.end());
      long long pairs = 0;
      for (std::size_t i = 0; i < values.size();) {
          std::size_t j = i;
          while (j < values.size() && values[j] == values[i]) ++j;
          const long long run = static_cast<long long>(j - i);
          pairs += run * (run - 1) / 2;
          i = j;
      }
      return pairs;
  }

  }  // namespace

  // Reads N and N-1 weighted edges (u v w, 1-based vertices) and prints the number
  // of unordered vertex pairs whose root-1 XOR values coincide.
  // Assumes the edges form a tree (connected). In a tree the XOR along the u-v path
  // equals px[u] ^ px[v], because the shared root-to-LCA segment cancels. So equal
  // px values are exactly the pairs whose path XOR is 0.
  int main() {
      std::ios::sync_with_stdio(false);
      std::cin.tie(nullptr);

      int n = 0;
      if (!(std::cin >> n)) return 0;  // no input at all: print nothing
      if (n <= 0) {
          // No vertices means no pairs; also avoids indexing vertex 1 in empty arrays.
          std::cout << 0 << '\n';
          return 0;
      }

      Graph adj(static_cast<std::size_t>(n) + 1);
      for (int i = 0; i < n - 1; ++i) {
          int u = 0;
          int v = 0;
          int w = 0;
          if (!(std::cin >> u >> v >> w)) {
              std::cerr << "error: expected " << (n - 1) << " edges, got " << i << '\n';
              return 1;
          }
          if (u < 1 || u > n || v < 1 || v > n) {
              std::cerr << "error: edge endpoint out of range [1, " << n << "]\n";
              return 1;
          }
          adj[u].emplace_back(v, w);
          adj[v].emplace_back(u, w);
      }

      const std::vector<std::uint32_t> px = prefix_xor_from_root(adj, n);
      // Skip the unused slot 0; only vertices 1..n take part in pairs.
      std::vector<std::uint32_t> values(px.begin() + 1, px.end());
      // Up to n*(n-1)/2 pairs, which exceeds 32-bit range for large n, hence long long.
      std::cout << count_equal_pairs(std::move(values)) << '\n';
      return 0;
  }
Success #stdin #stdout 0.01s 5308KB
stdin
5
1 2 1
1 3 2
3 4 3
3 5 0
stdout
2