Given an integer array nums, find and return the maximum possible sum of a subset of its elements such that the sum is divisible by three.
Example 1:
Input: nums = [3,6,5,1,8]
Output: 18
Explanation: Choose the numbers 3, 6, 1 and 8 whose sum is 18 (the largest sum divisible by 3).
Example 2:
Input: nums = [4]
Output: 0
Explanation: 4 is not divisible by 3, so no number can be selected; return 0.
Example 3:
Input: nums = [1,2,3,4,4]
Output: 12
Explanation: Select the numbers 1, 3, 4 and 4 whose sum is 12 (the largest sum divisible by 3).
Constraints:
1 <= nums.length <= 4 * 10^4
1 <= nums[i] <= 10^4
Link: https://leetcode-cn.com/problems/greatest-sum-divisible-by-three
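As a reference for checking the examples, a brute-force sketch (not the intended solution) can enumerate every subset with a bitmask and keep the largest sum divisible by 3. It takes O(2^n * n) time, so it is only usable for tiny inputs, far below the constraint n <= 4 * 10^4. The function name bruteForceMaxSumDivThree is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical reference helper: tries every subset via a bitmask, O(2^n * n).
int bruteForceMaxSumDivThree(const vector<int>& nums) {
    int n = nums.size();
    assert(n <= 20);  // exponential time: only for small inputs
    int best = 0;     // the empty subset has sum 0, which is divisible by 3
    for (int mask = 0; mask < (1 << n); mask++) {
        int sum = 0;
        for (int i = 0; i < n; i++) {
            if (mask & (1 << i)) sum += nums[i];  // element i is in this subset
        }
        if (sum % 3 == 0) best = max(best, sum);
    }
    return best;
}
```

For example, bruteForceMaxSumDivThree({3, 6, 5, 1, 8}) returns 18, matching Example 1, and {1, 2, 3, 4, 4} gives 12, matching Example 3.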
Analysis:
Input: nums = [3,6,5,1,8]
Output: 18
Explanation: Choose the numbers 3, 6, 1 and 8 whose sum is 18 (the largest sum divisible by 3).
This is a variant of the 0-1 knapsack problem: each number is either selected or not, and can be selected at most once.
Note that the sample also selects 1 and 8, neither of which is divisible by 3: 1 % 3 = 1 and 8 % 3 = 2, but 1 + 8 = 9 is divisible by 3. So to find the largest sum with remainder 0, we must also keep track of the largest sums with remainder 1 and remainder 2, still using each number at most once.
This gives the following state definition:
dp[i][j] is the maximum sum of numbers chosen from the first i numbers such that the sum modulo 3 equals j.
Initially (i = 0, no numbers chosen), dp[0][0] = 0, while dp[0][1] and dp[0][2] are set to INT_MIN to mark those remainders as unreachable.
The figure below shows the dp table.
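State transition equation, with t = nums[i-1] % 3:
dp[i][j] = max(dp[i-1][j], dp[i-1][(j - t + 3) % 3] + nums[i-1])
The first term skips nums[i-1]; the second selects it, starting from the remainder (j - t + 3) % 3, which becomes j once nums[i-1] is added.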
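For example, take t = nums[i-1] % 3 == 2. Then dp[i][0] can be obtained from two states of the first i-1 numbers:
- dp[i-1][1] + nums[i-1] (select nums[i-1]: remainder 1 + 2 = 3, which is 0 modulo 3)
- dp[i-1][0] (skip nums[i-1])
The larger of these two values is the maximum sum with remainder 0 among the first i numbers.
This relies on dp[i-1][1] already being the maximum sum with remainder 1, so all three states dp[i][0], dp[i][1] and dp[i][2] must be updated at every step, with a separate set of transitions for each value of t (0, 1 or 2).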
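To make the t == 2 case concrete, here is a sketch of just that transition, applied to the last step of Example 1 (i = 5, nums[4] = 8). For the first four numbers [3, 6, 5, 1], the row is dp[4] = {15, 10, 14}: 15 = 3 + 6 + 5 + 1, 10 = 3 + 6 + 1 and 14 = 3 + 6 + 5. The helper name stepRemainderTwo is hypothetical; it uses INT_MIN for unreachable remainders, as the original code does, but checks for it explicitly before adding.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
using namespace std;

// Hypothetical helper: one dp step for a number x with x % 3 == 2.
// prev holds dp[i-1][0..2]; INT_MIN marks an unreachable remainder.
array<int, 3> stepRemainderTwo(const array<int, 3>& prev, int x) {
    assert(x % 3 == 2);
    // Select x on top of dp[i-1][from], unless that state is unreachable.
    auto take = [&](int from) { return prev[from] == INT_MIN ? INT_MIN : prev[from] + x; };
    return {
        max(prev[0], take(1)),  // remainder 1 + 2 = 3 -> 0
        max(prev[1], take(2)),  // remainder 2 + 2 = 4 -> 1
        max(prev[2], take(0)),  // remainder 0 + 2 -> 2
    };
}
```

Here stepRemainderTwo({15, 10, 14}, 8) returns {18, 22, 23}; its first entry, dp[5][0] = max(15, 10 + 8) = 18, is the answer to Example 1.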
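The code:
#include <vector>
#include <climits>
#include <algorithm>
using namespace std;

class Solution {
public:
    int maxSumDivThree(vector<int>& nums) {
        int n = nums.size();
        // dp[i][j]: maximum sum chosen from the first i numbers with sum % 3 == j.
        // std::vector replaces the variable-length array int dp[n+2][3], which is not standard C++.
        vector<vector<int>> dp(n + 1, vector<int>(3, 0));
        // No numbers chosen: only remainder 0 (the empty sum) is reachable.
        dp[0][0] = 0; dp[0][1] = INT_MIN; dp[0][2] = INT_MIN;
        // INT_MIN + nums[i-1] cannot overflow, and since the total sum is at most
        // 4 * 10^4 * 10^4 = 4 * 10^8, values built on INT_MIN stay negative and
        // never beat a reachable (non-negative) state.
        for (int i = 1; i <= n; i++) {
            int t = nums[i-1] % 3;
            if (t == 0) {
                // Adding a multiple of 3 leaves the remainder unchanged.
                dp[i][0] = max(dp[i-1][0], dp[i-1][0] + nums[i-1]);
                dp[i][1] = max(dp[i-1][1], dp[i-1][1] + nums[i-1]);
                dp[i][2] = max(dp[i-1][2], dp[i-1][2] + nums[i-1]);
            } else if (t == 1) {
                // Remainders move 2 -> 0, 0 -> 1, 1 -> 2.
                dp[i][0] = max(dp[i-1][0], dp[i-1][2] + nums[i-1]);
                dp[i][1] = max(dp[i-1][1], dp[i-1][0] + nums[i-1]);
                dp[i][2] = max(dp[i-1][2], dp[i-1][1] + nums[i-1]);
            } else {
                // t == 2: remainders move 1 -> 0, 2 -> 1, 0 -> 2.
                dp[i][0] = max(dp[i-1][0], dp[i-1][1] + nums[i-1]);
                dp[i][1] = max(dp[i-1][1], dp[i-1][2] + nums[i-1]);
                dp[i][2] = max(dp[i-1][2], dp[i-1][0] + nums[i-1]);
            }
        }
        return dp[n][0];
    }
};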
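The three cases in the solution can be folded into one formula: with t = x % 3, the remainder that becomes j after adding x is (j - t + 3) % 3. Since dp[i] depends only on dp[i-1], one row is enough. The sketch below is a space-optimized variant under the problem's constraints (all numbers positive). The function name maxSumDivThreeRolling and the use of -1 instead of INT_MIN as the "unreachable" marker are choices made for this sketch, not part of the original solution.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical rolling-array version of the same dp, O(1) extra space.
// prev[j] = maximum sum with remainder j among the numbers seen so far, or -1 if unreachable.
int maxSumDivThreeRolling(const vector<int>& nums) {
    array<int, 3> prev = {0, -1, -1};  // empty prefix: only remainder 0 is reachable
    for (int x : nums) {
        assert(x >= 1);  // the problem guarantees 1 <= nums[i]
        int t = x % 3;
        array<int, 3> cur = prev;  // skip x: dp[i][j] = dp[i-1][j]
        for (int j = 0; j < 3; j++) {
            int from = (j - t + 3) % 3;  // remainder that becomes j once x is added
            if (prev[from] >= 0) {       // select x, but only from a reachable state
                cur[j] = max(cur[j], prev[from] + x);
            }
        }
        prev = cur;
    }
    return prev[0];
}
```

Checking reachability explicitly avoids doing arithmetic on a sentinel value; on Example 1 the row evolves to {18, 22, 23}, and the function returns 18.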