思路:
点分治
先线性预处理出 1e6+3（质数）以内所有数的逆元
然后用数组 mp（充当 map）把以分治点为起点的链的乘积 a 映射为该链另一端点的下标 u；同一乘积只保留最小的下标
然后对分治点的每个儿子，暴力求出以该儿子为起点（不含分治点）的链的乘积 b，在 mp 里查找 inv[b]*k（模 1e6+3）；查询完这棵子树后，再把它的链（含分治点）插入 mp，这样配对的两个端点一定来自不同子树
代码:
// Centroid decomposition: find the lexicographically smallest pair (a, b),
// a < b, such that the product of the node values on the tree path a..b
// (both endpoints included) is congruent to k modulo MOD.
// Input (repeated until EOF): n k, then n node values, then n-1 edges.
// NOTE(review): assumes node values lie in [1, MOD), so they are valid
// indices into mp/inv and always invertible -- confirm against the problem.
// NOTE(review): get_rt/get_size/get_d recurse to the tree depth (up to n);
// deep chains rely on the judge's stack size.
#include <algorithm>
#include <cstdio>
#include <cstring>

const int MOD = 1000000 + 3;  // prime, so every nonzero residue has an inverse
const int INF = 0x7f7f7f7f;
const int N = 100000 + 5;

int inv[MOD + 5];  // inv[x] = x^{-1} mod MOD for 1 <= x < MOD
int mp[MOD + 5];   // path product (from current centroid) -> smallest endpoint id; 0 = empty
int head[N], mxsz[N], sz[N], v[N];
int cnt = 0, rt = 0, n, k, ans1, ans2;
int deep[N], dis[N], id[N], top = 0;  // scratch buffers filled by get_d
bool vis[N];                          // vis[u]: u was already used as a centroid

struct Edge {
    int to, nxt;
} edge[N * 2];

// Append the directed edge from -> to to the adjacency list.
void add_edge(int from, int to) {
    edge[cnt].to = to;
    edge[cnt].nxt = head[from];
    head[from] = cnt++;
}

// Linear-time table of modular inverses (valid because MOD is prime):
// inv[i] = -(MOD / i) * inv[MOD % i]  (mod MOD).
void init() {
    inv[1] = 1;
    for (int i = 2; i < MOD; i++)
        inv[i] = static_cast<int>(1LL * (MOD - MOD / i) * inv[MOD % i] % MOD);
}

// x is the product of a path that starts at node y and ends just below the
// current centroid (centroid not included). Look up a stored centroid path
// whose product completes x to k and keep the lexicographically smallest
// (min, max) endpoint pair seen so far.
void update(int x, int y) {
    const int need = static_cast<int>(1LL * inv[x] * k % MOD);
    int other = mp[need];
    if (!other) return;
    if (other > y) std::swap(other, y);
    if (other < ans1 || (other == ans1 && y < ans2)) {
        ans1 = other;
        ans2 = y;
    }
}

// Find the centroid of the current component (whose size is n) by a DFS
// from u with parent o. The caller sets rt = 0 and mxsz[0] = n first; the
// centroid is left in rt.
void get_rt(int o, int u) {
    sz[u] = 1;
    mxsz[u] = 0;
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        const int to = edge[i].to;
        if (to != o && !vis[to]) {
            get_rt(u, to);
            sz[u] += sz[to];
            mxsz[u] = std::max(mxsz[u], sz[to]);
        }
    }
    mxsz[u] = std::max(mxsz[u], n - sz[u]);
    if (mxsz[u] < mxsz[rt]) rt = u;
}

// Number of not-yet-removed nodes reachable from u without going through o.
// Used instead of sz[] from the previous get_rt pass, which is stale for the
// child that was the DFS parent of the centroid.
int get_size(int o, int u) {
    int total = 1;
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        const int to = edge[i].to;
        if (to != o && !vis[to]) total += get_size(u, to);
    }
    return total;
}

// Push (dis[u], u) for u and every node below it (parent o) onto
// deep[]/id[], extending the running product dis[] along the way.
void get_d(int o, int u) {
    deep[++top] = dis[u];
    id[top] = u;
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        const int to = edge[i].to;
        if (!vis[to] && to != o) {
            dis[to] = static_cast<int>(1LL * dis[u] * v[to] % MOD);
            get_d(u, to);
        }
    }
}

// Collect every path from child downward inside child's subtree of centroid
// u into deep[1..top] / id[1..top]. When with_centroid is true the product
// also includes v[u].
void collect(int u, int child, bool with_centroid) {
    top = 0;
    dis[child] = with_centroid ? static_cast<int>(1LL * v[u] * v[child] % MOD)
                               : v[child];
    get_d(u, child);
}

// Process centroid u: count paths through u, then recurse into each subtree.
void solve(int u) {
    vis[u] = true;
    mp[v[u]] = u;  // the single-node path {u}

    for (int i = head[u]; ~i; i = edge[i].nxt) {
        const int to = edge[i].to;
        if (vis[to]) continue;
        // Query first, so the two endpoints always lie in different subtrees
        // (or one of them is u itself).
        collect(u, to, false);
        for (int j = 1; j <= top; j++) update(deep[j], id[j]);
        // Then publish this subtree's paths (including u), keeping the
        // smallest endpoint id per product.
        collect(u, to, true);
        for (int j = 1; j <= top; j++) {
            const int p = deep[j];
            if (!mp[p] || id[j] < mp[p]) mp[p] = id[j];
        }
    }

    // Undo every insertion so mp is all zeros again before recursing.
    mp[v[u]] = 0;
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        const int to = edge[i].to;
        if (vis[to]) continue;
        collect(u, to, true);
        for (int j = 1; j <= top; j++) mp[deep[j]] = 0;
    }

    for (int i = head[u]; ~i; i = edge[i].nxt) {
        const int to = edge[i].to;
        if (vis[to]) continue;
        n = get_size(u, to);
        mxsz[0] = n;
        rt = 0;
        get_rt(0, to);
        solve(rt);
    }
}

int main() {
    init();
    int a, b;
    // Stop on EOF *or* malformed input (the old `~scanf` only stopped on EOF).
    while (std::scanf("%d%d", &n, &k) == 2) {
        std::memset(head, -1, sizeof(head));
        std::memset(vis, false, sizeof(vis));
        std::memset(mp, 0, sizeof(mp));
        cnt = 0;
        ans1 = ans2 = INF;
        for (int i = 1; i <= n; i++) std::scanf("%d", &v[i]);
        for (int i = 1; i < n; i++) {
            std::scanf("%d%d", &a, &b);
            add_edge(a, b);
            add_edge(b, a);
        }
        mxsz[0] = n;
        rt = 0;
        get_rt(0, 1);
        solve(rt);
        if (ans1 == INF)
            std::printf("No solution\n");
        else
            std::printf("%d %d\n", ans1, ans2);
    }
    return 0;
}