L Nice Trick

题目描述

Given $n$ integers $a_1, a_2, \dots, a_n$, Bobo knows how to compute the sum of triples $S_3 = \sum_{1 \leq i < j < k \leq n} a_i a_j a_k$.
It follows that $S_3 = \dfrac{\left(\sum_{1 \leq i \leq n} a_i\right)^3 - 3 \left(\sum_{1 \leq i \leq n} a_i^2\right)\left(\sum_{1 \leq i \leq n} a_i\right) + 2\left(\sum_{1 \leq i \leq n} a_i^3\right)}{6}$.

Bobo would like to compute the sum of quadrangles $\left(\sum_{1 \leq i < j < k < l \leq n} a_i a_j a_k a_l\right) \bmod (10^9+7)$.

输入描述:

The input contains zero or more test cases and is terminated by end-of-file. For each test case,
The first line contains an integer n.
The second line contains $n$ integers $a_1, a_2, \dots, a_n$.

  • $1 \leq n \leq 10^5$
  • $0 \leq a_i \leq 10^9$
  • The number of test cases does not exceed 10.

输出描述:

For each case, output an integer which denotes the result.

示例1
输入

3
1 2 3
4
1 2 3 4
5
1 2 3 4 5

输出
0
24
274

题目已经给出了三元组乘积之和 $S_3$ 的求法。求四元组乘积之和时,只需枚举最大的下标 $l$:用前缀和(一次幂、二次幂、三次幂的前缀和)按公式求出前 $l-1$ 个数的 $S_3$,再乘以 $a_l$ 并累加即可。
最后,由于取模后做减法可能得到负数,不要忘了将结果先加上 mod 再对 mod 取模。

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#define mem(a, b) memset(a, b, sizeof(a))
using namespace std;


// 64-bit signed alias used for every modular computation in this file.
typedef long long ll;

// Capacity of the global arrays: the stated bound on n (10^5) plus a little slack,
// since the arrays are indexed from 1.
const int maxn = 100000 + 10;

// Modulus demanded by the problem statement: 10^9 + 7.
const ll mod = 1000000007LL;

// a[i]    : i-th input value, reduced modulo mod by main (1-based).
// sum[i]  : (a[1]   + ... + a[i])   mod `mod`
// sum2[i] : (a[1]^2 + ... + a[i]^2) mod `mod`
// sum3[i] : (a[1]^3 + ... + a[i]^3) mod `mod`
ll a[maxn];
ll sum[maxn];
ll sum2[maxn];
ll sum3[maxn];

// Extended Euclidean algorithm.
//
// Returns d = gcd(a, b) and stores in x, y a pair of coefficients
// satisfying a*x + b*y == d.
ll exgcd(ll a, ll b, ll &x, ll &y)
{
	if(b != 0)
	{
		// Recurse on (b, a mod b) with the output slots swapped: afterwards
		// x holds the coefficient of (a mod b) and y the coefficient of b,
		// i.e. b*y + (a mod b)*x == d.
		ll g = exgcd(b, a % b, y, x);
		// Substitute a mod b = a - (a / b) * b to get coefficients for (a, b).
		y -= a / b * x;
		return g;
	}

	// Base case: gcd(a, 0) = a = a*1 + 0*0.
	x = 1;
	y = 0;
	return a;
}

// Modular inverse of a modulo p, obtained from the Bezout coefficient
// that exgcd computes for a.
//
// The gcd itself is not inspected, so the result is only a true inverse
// when gcd(a, p) == 1. The coefficient is shifted into [0, p) for p > 0.
ll getinv(ll a, ll p)
{
	ll coefA, coefP;
	exgcd(a, p, coefA, coefP);

	// The coefficient may be negative; reduce, shift by p, reduce again.
	ll reduced = coefA % p;
	return (reduced + p) % p;
}

// Sum of a_i * a_j * a_k over all 1 <= i < j < k <= n, modulo `mod`.
//
// Uses the identity
//   S3 = ( (sum a)^3 - 3 * (sum a^2) * (sum a) + 2 * (sum a^3) ) / 6
// evaluated on the prefix arrays sum[n], sum2[n], sum3[n], which are
// expected to hold the prefix sums of a, a^2 and a^3 modulo `mod`
// (main fills them in before calling this function).
//
// The result is always in [0, mod).
ll solve(int n)
{
	// 6^{-1} mod `mod` is the same for every call; compute it once and
	// keep it in a 64-bit variable so nothing is narrowed.
	static const ll inv6 = getinv(6, mod);

	const ll s1 = sum[n] % mod;
	const ll s2 = sum2[n] % mod;
	const ll s3 = sum3[n] % mod;

	const ll cube = s1 * s1 % mod * s1 % mod;   // (sum a)^3
	const ll cross = s2 * s1 % mod * 3 % mod;   // 3 * (sum a^2) * (sum a)
	const ll twice = 2 * s3 % mod;              // 2 * (sum a^3)

	// cube - cross may be negative after reduction; shift it back into
	// [0, mod) here so callers never see a negative residue.
	const ll numerator = ((cube - cross + twice) % mod + mod) % mod;

	return numerator * inv6 % mod;
}

// Zero all four global arrays (a, sum, sum2, sum3) in full so that
// nothing from a previous test case is left behind.
void init()
{
	memset(a, 0, sizeof(a));
	memset(sum, 0, sizeof(sum));
	memset(sum2, 0, sizeof(sum2));
	memset(sum3, 0, sizeof(sum3));
}

int main()
{
	int n;
	while(scanf("%d", &n) != EOF)
	{
		init();
		for(int i = 1; i <= n;i++)
			scanf("%lld", &a[i]), a[i] %= mod;
		for(int i =1 ;i <= n; i++)
		{
			ll a2 = 1, a3 = 1;
			for(int j = 0; j < 3; j++)
			{
				if(j < 2)
					(a2 *= a[i] ) %= mod, a3 = a2;
				else
					(a3 *= a[i]) %= mod;
			}
			(sum[i] += sum[i - 1] + a[i] % mod) %= mod;
			(sum2[i] += (sum2[i - 1] + a2) % mod) %= mod;
			(sum3[i] += (sum3[i - 1] + a3) % mod) %= mod;
		}
		ll ans = 0;
		ll sum = 0;
		for(int i = 4; i <= n; i++)
			(ans += solve(i - 1) * a[i] % mod) %= mod;
		(ans += mod) %= mod;
		printf("%lld\n", ans);
	}
	return 0;
}
发布了73 篇原创文章 · 获赞 15 · 访问量 8104

猜你喜欢

转载自blog.csdn.net/ln2037/article/details/102382752