Dictionaries and Hashmaps: Count Triplets

#!/bin/python3
'''
import math
import os
import random
import re
import sys

# Complete the countTriplets function below.
#naive way
def countTriplets0(arr, r):
    """Count index triplets (i < j < k) that form a geometric progression.

    Brute-force reference implementation: every ordered triple of positions
    is examined, so the running time is O(n^3). Intended for small inputs
    and for cross-checking faster implementations.

    Args:
        arr: Sequence of integers, examined in its original order.
        r: Common ratio of the progression.

    Returns:
        The number of position triples with ``arr[j] == arr[i] * r`` and
        ``arr[k] == arr[j] * r``.
    """
    count = 0
    n = len(arr)
    for i in range(n - 2):
        for j in range(i + 1, n - 1):
            # Compare by multiplication: stays in exact integer arithmetic
            # and cannot divide by zero when an element is 0.
            if arr[j] != arr[i] * r:
                continue
            third = arr[j] * r
            for k in range(j + 1, n):
                if arr[k] == third:
                    count += 1
    return count

#more efficient way
from collections import defaultdict


def countTriplets(arr, r):
    """Count index triplets (i < j < k) that form a geometric progression.

    Every element is considered in turn as the middle term of a triplet.
    For a middle value ``m`` the number of triplets through it is
    (occurrences of ``m / r`` to its left) * (occurrences of ``m * r`` to
    its right), which respects the original ordering of ``arr`` and needs
    only one pass plus two hash maps.

    Args:
        arr: Sequence of integers, examined in its original order.
        r: Common ratio of the progression; must be a non-zero integer.

    Returns:
        The number of triplets, as an ``int``.
    """
    # Occurrences at or after the current position; starts as a full tally.
    right = defaultdict(int)
    for value in arr:
        right[value] += 1
    # Occurrences strictly before the current position.
    left = defaultdict(int)

    count = 0
    for mid in arr:
        # The current element can no longer serve as a third term.
        right[mid] -= 1
        # Only values divisible by r can have an integer predecessor.
        if mid % r == 0:
            # .get avoids inserting empty entries into the defaultdicts.
            count += left.get(mid // r, 0) * right.get(mid * r, 0)
        left[mid] += 1
    return count


if __name__ == '__main__':
    # Reads "n r" from the first input line and the array from the second,
    # then writes the triplet count to the file named by OUTPUT_PATH.
    # The context manager guarantees the output file is closed even when
    # parsing or counting raises.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        nr = input().rstrip().split()

        n = int(nr[0])  # parsed but not otherwise used

        r = int(nr[1])

        arr = list(map(int, input().rstrip().split()))

        ans = countTriplets(arr, r)

        fptr.write(str(ans) + '\n')

'''
import math
import os
import random
import re
import sys
from collections import defaultdict

# Complete the countTriplets function below.
def countTriplets(arr, r):
    """Count index triplets (i < j < k) that form a geometric progression.

    Single left-to-right pass that tracks, for every value ``v``:

    * how many earlier elements are waiting for ``v`` as their second term,
    * how many earlier (first, second) pairs are waiting for ``v`` as their
      third term.

    Args:
        arr: Iterable of numbers, processed in order.
        r: Common ratio of the progression.

    Returns:
        The number of triplets found.
    """
    awaiting_second = {}
    awaiting_third = {}
    total = 0

    for value in arr:
        # Every pair that was waiting for this value is now a full triplet.
        total += awaiting_third.get(value, 0)

        successor = r * value
        # Elements waiting for `value` as second term now wait for
        # `successor` as third term. This must be read before registering
        # `value` itself, which matters when r == 1.
        awaiting_third[successor] = (
            awaiting_third.get(successor, 0) + awaiting_second.get(value, 0)
        )
        # `value` itself may start a new triplet.
        awaiting_second[successor] = awaiting_second.get(successor, 0) + 1

    return total

if __name__ == '__main__':
    # Reads "n r" from the first input line and the array from the second,
    # then writes the triplet count to the file named by OUTPUT_PATH.
    # The context manager guarantees the output file is closed even when
    # parsing or counting raises.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        nr = input().rstrip().split()

        n = int(nr[0])  # parsed but not otherwise used

        r = int(nr[1])

        arr = list(map(int, input().rstrip().split()))

        ans = countTriplets(arr, r)

        fptr.write(str(ans) + '\n')
#'''
发布了163 篇原创文章 · 获赞 90 · 访问量 6276

猜你喜欢

转载自blog.csdn.net/weixin_45405128/article/details/104242583