1. Spectrum Compression

import soundfile as sf
import torch
from scipy.io import wavfile
import os
import numpy as np

audio_file_save = "/home/*/wav/1.wav"
clean, s = sf.read("/home/*/clean_trainset_wav/p287_424.wav")
x1 = torch.stft(torch.tensor(clean), n_fft=512, hop_length=100, win_length=400, return_complex=False)  # (257, 641, 2); no window given, so a rectangular window is used
x1 = x1.unsqueeze(dim=0)  # (1, 257, 641, 2)
x1 = x1.permute(0, 3, 2, 1)  # (1, 2, 641, 257)

# apply square-root power compression to the magnitude, keeping the phase
x1_mag, x1_phase = torch.norm(x1, dim=1) ** 0.5, torch.atan2(x1[:, -1, ...], x1[:, 0, ...])  # (1, 641, 257) each
x1 = torch.stack((x1_mag * torch.cos(x1_phase), x1_mag * torch.sin(x1_phase)), dim=1)  # (1, 2, 641, 257)
# pad = torch.nn.ZeroPad2d((0, 1, 0, 0))
# x1 = pad(x1)
out1 = x1.permute(0, 3, 2, 1)  # (1, 257, 641, 2)
# NOTE: PyTorch >= 2.0 expects a complex tensor here; if needed, convert with
# torch.view_as_complex(out1.contiguous()) first.
out1 = torch.istft(out1, 512, hop_length=100, win_length=400, return_complex=False)  # (1, T)
out1 = out1.squeeze(0).numpy()  # scipy.io.wavfile expects a 1-D (or [T, C]) array
wavfile.write(audio_file_save, 16000, out1.astype(np.float32))
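
To undo the compression after processing, the compressed magnitude is simply squared before resynthesis. A minimal round-trip sketch on a synthetic spectrum (illustrative shapes, no file I/O) shows the pair of operations is lossless up to numerical precision:

import torch

spec = torch.randn(1, 2, 641, 257)                        # stand-in (B, re/im, frames, bins) spectrum
mag = torch.norm(spec, dim=1)                             # linear magnitude
phase = torch.atan2(spec[:, -1, ...], spec[:, 0, ...])    # phase
comp_mag = mag ** 0.5                                     # sqrt power compression
recon_mag = comp_mag ** 2                                 # decompression
recon = torch.stack((recon_mag * torch.cos(phase),
                     recon_mag * torch.sin(phase)), dim=1)
print(torch.allclose(recon, spec, atol=1e-5))             # True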

2. Testing Code


import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from math import sqrt, log10, ceil
from pystoi import stoi
from pesq import pesq
from pesq.cypesq import PesqError
from pypesq import pesq as pesq_mos
import argparse
from mir_eval.separation import bss_eval_sources
from pit_criterion import cal_loss
from collections import OrderedDict
import numpy as np
import torch
from data_whamr import TestDataset, TestCollate
from D2Net.mc_power_compression_D2Net import Net
from torch.utils.data import DataLoader
# from D2Net.mc_D2NET import Net
# from D2Net.mc_bss_D2NET import Net
import torch.nn as nn


def remove_pad(inputs, inputs_lengths):
    """
    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
        inputs_lengths: torch.Tensor, [B]
    Returns:
        results: a list containing B items, each item is [C, T], T varies
    """
    results = []
    dim = inputs.dim()
    if dim == 3:
        C = inputs.size(1)
    for input, length in zip(inputs, inputs_lengths):
        if dim == 3:  # [B, C, T]
            results.append(input[:, :length].view(C, -1).cpu().numpy())
        elif dim == 2:  # [B, T]
            results.append(input[:length].view(-1).cpu().numpy())
    return results
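
# Illustrative usage sketch for remove_pad (dummy tensors, shapes only; not
# real audio). Call _demo_remove_pad() manually to run it.
def _demo_remove_pad():
    batch = torch.randn(2, 2, 5)           # [B=2, C=2, T=5], zero-padded batch
    lengths = torch.tensor([3, 5])         # true lengths before padding
    trimmed = remove_pad(batch, lengths)
    print([t.shape for t in trimmed])      # [(2, 3), (2, 5)]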


parser = argparse.ArgumentParser('Evaluate enhancement performance of the D2Net model')
# parser.add_argument('--test_dir', type=str, default="/home/weiWB/dataset/FaSNet_TAC/or_0.25_spk1_snr/MC_Libri_adhoc/test/2mic/", help='path to test/2mic/samples')
parser.add_argument('--test_dir', type=str,
                    default="/home/wangLS/dataset/mc_whamr/mc_whamr_mix_5/mc_whamr_tt_5R_list",
                    help='path to test/2mic/samples')
parser.add_argument('--model_path', type=str, default='/home/weiWB/code/FaSNet-TAC-PyTorch20/D2Net/exp/tmp/mc__power_compression_D2Net_whamr5/temp_best.pth.tar',
                    help='Path to model file created by training')
parser.add_argument('--cal_sdr', type=int, default=0,
                    help='Whether calculate SDR, add this option because calculation of SDR is very slow')
parser.add_argument('--use_cuda', type=int, default=1, help='Whether use GPU to separate speech')

# General config
# Task related
parser.add_argument('--sample_rate', default=16000, type=int, help='Sample rate')

# Network architecture

parser.add_argument('--mic', default=2, type=int, help='number of microphones')


class ISTFT(nn.Module):
    def __init__(self, n_fft=512, hop_length=100, window_length=400):
        super(ISTFT, self).__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.window_length = window_length
        self.pad = torch.nn.ZeroPad2d((0, 1, 0, 0))  # pad 256 -> 257 frequency bins

    def forward(self, x):
        out = self.pad(x)  # (B, 2, T, 257)
        out = out.permute(0, 3, 2, 1)  # (B, 257, T, 2)
        # NOTE: PyTorch >= 2.0 expects a complex tensor here; if needed,
        # convert with torch.view_as_complex(out.contiguous()) first.
        out = torch.istft(out, n_fft=self.n_fft, hop_length=self.hop_length,
                          win_length=self.window_length,
                          return_complex=False).unsqueeze(1)  # (B, 1, T)
        return out
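
# Illustrative shape check for the ISTFT wrapper (random spectrum, not real
# audio); call _demo_istft_shapes() manually. Runs as written only on PyTorch
# versions that still accept real (..., 2) inputs to torch.istft.
def _demo_istft_shapes():
    dummy = torch.randn(1, 2, 641, 256)    # (B, re/im, frames, 256 bins)
    wav = ISTFT()(dummy)
    print(wav.shape)                       # torch.Size([1, 1, 64000])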

def evaluate(args):
    eps = 1e-8
    total_SISNRi = 0
    total_SISNR = 0
    total_wb_pesq = 0
    total_nb_pesq_mos = 0
    total_nb_pesq = 0
    total_stoi = 0
    total_SDRi = 0
    total_sdr = 0
    total_cnt = 0

    # load D2Net model
    model = Net()

    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    # model.load_state_dict(torch.load(args.model_path, map_location='cpu'))

    model_info = torch.load(args.model_path)
    try:
        model.load_state_dict(model_info['model_state_dict'])
    except RuntimeError:
        # checkpoint keys carry a DataParallel 'module.' prefix the current
        # model does not expect; strip it and retry
        state_dict = OrderedDict()
        for k, v in model_info['model_state_dict'].items():
            name = k.replace("module.", "")  # remove 'module.'
            state_dict[name] = v
        model.load_state_dict(state_dict)

    print(model)
    model.eval()


    # whamr data_loader
    test_data = TestDataset(args.test_dir)
    data_loader = DataLoader(test_data,
                             batch_size=1,
                             shuffle=False,
                             num_workers=0,
                             collate_fn=TestCollate()
                             )

    with torch.no_grad():
        for i, data in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data

            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()  # tensor ([1,2,64000])
                mixture_lengths = mixture_lengths.cuda()  # tensor[64000,64000]
                padded_source = padded_source.cuda()  # tensor ([1,2,64000])


            # D2Net
            estimate_source = model(padded_mixture)

            # mc_power_compression_D2Net
            padded_source = torch.split(padded_source, 1, dim=1)
            padded_source = padded_source[0]  # B,1,L

            istft = ISTFT()
            estimate_source = istft(estimate_source)

            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)

            M, _, T = padded_mixture.shape  # B,C,T
            mixture_ref = torch.chunk(padded_mixture, args.mic, dim=1)[0]  # [B, 1, T]
            mixture_ref = mixture_ref.view(M, T)  # [ 1, T]

            mixture = remove_pad(mixture_ref, mixture_lengths)  # ndarray(T,)
            source = remove_pad(padded_source, mixture_lengths)  # ndarray(C,T)
            estimate_source = remove_pad(reorder_estimate_source, mixture_lengths)  # ndarray(C,T)


            # for each utterance  mix:ndarray(T,);src_ref:(C,T);src_est:(C,T)
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    # print("\tSDRi={0:.2f}".format(avg_SDRi))

                src_ref = np.reshape(src_ref, [-1])  # flatten (C, T) to 1-D for the metrics
                src_est = np.reshape(src_est, [-1])

                # wide-band PESQ (pesq_batch computes PESQ for multiple inputs/outputs)
                avg_wb_pesq = pesq(args.sample_rate, src_ref, src_est, on_error=PesqError.RETURN_VALUES)

                # pystoi
                avg_stoi = stoi(src_ref, src_est, args.sample_rate)

                total_wb_pesq += avg_wb_pesq
                total_stoi += avg_stoi
                total_cnt += 1
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi / (total_cnt + eps)))

    print("Average wb_pesq improvement: {0:.2f}".format(total_wb_pesq / (total_cnt + eps)))
    print("Average stoi improvement: {0:.2f}".format(total_stoi / (total_cnt + eps)))
    


def cal_SDRi(src_ref, src_est, mix):
    """Calculate Source-to-Distortion Ratio improvement (SDRi).
    NOTE: bss_eval_sources is very very slow.
    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SDRi
    """
    src_anchor = np.stack([mix, mix], axis=0)
    sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
    sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
    avg_SDRi = ((sdr[0] - sdr0[0]) + (sdr[1] - sdr0[1])) / 2.0
    # print("SDRi1: {0:.2f}, SDRi2: {1:.2f}".format(sdr[0]-sdr0[0], sdr[1]-sdr0[1]))
    return avg_SDRi
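
# Illustrative sanity check for cal_SDRi (random signals, not real speech);
# call _demo_cal_SDRi() manually -- bss_eval is slow, so the signals are short.
def _demo_cal_SDRi():
    src_ref = np.random.randn(2, 8000)                  # two reference sources
    src_est = src_ref + 0.1 * np.random.randn(2, 8000)  # noisy estimates
    mix = src_ref.sum(axis=0)                           # single-channel mixture
    print(cal_SDRi(src_ref, src_est, mix))              # > 0 dB improvement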


def cal_SISNRi(src_ref, src_est, mix):
    """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi)
    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SISNRi
    """
    sisnr1 = cal_SISNR(src_ref[0], src_est[0])
    sisnr2 = cal_SISNR(src_ref[1], src_est[1])
    sisnr1b = cal_SISNR(src_ref[0], mix)
    sisnr2b = cal_SISNR(src_ref[1], mix)
    # print("SISNR base1 {0:.2f} SISNR base2 {1:.2f}, avg {2:.2f}".format(
    #     sisnr1b, sisnr2b, (sisnr1b+sisnr2b)/2))
    # print("SISNRi1: {0:.2f}, SISNRi2: {1:.2f}".format(sisnr1, sisnr2))
    avg_SISNRi = ((sisnr1 - sisnr1b) + (sisnr2 - sisnr2b)) / 2.0


    return avg_SISNRi


def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
    Args:
        ref_sig: numpy.ndarray, [T]
        out_sig: numpy.ndarray, [T]
    Returns:
        SISNR
    """
    assert len(ref_sig) == len(out_sig)
    ref_sig = ref_sig - np.mean(ref_sig)
    out_sig = out_sig - np.mean(out_sig)
    ref_energy = np.sum(ref_sig ** 2) + eps
    proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy
    noise = out_sig - proj
    ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
    sisnr = 10 * np.log10(ratio + eps)
    return sisnr
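
# Illustrative demo of the scale invariance of SI-SNR (synthetic signals):
# rescaling the estimate by any nonzero factor leaves the score essentially
# unchanged. Call _demo_cal_SISNR() manually.
def _demo_cal_SISNR():
    ref = np.random.randn(16000)
    est = ref + 0.1 * np.random.randn(16000)
    print(cal_SISNR(ref, est))         # roughly 20 dB
    print(cal_SISNR(ref, 3.0 * est))   # same value: scale does not matter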


def cal_si_snr(source, estimate_source):
    """Calculate SI-SNR.
    Arguments:
    ---------
    source: [T, B, C],
        Where B is batch size, T is the length of the sources, C is the number of sources
        the ordering is made so that this loss is compatible with the class PitWrapper.
    estimate_source: [T, B, C]
        The estimated source.
    Example:
    ---------
    >>> import numpy as np
    >>> x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
    >>> xhat = x[:, (1, 0)]
    >>> x = x.unsqueeze(-1).repeat(1, 1, 2)
    >>> xhat = xhat.unsqueeze(1).repeat(1, 2, 1)
    >>> si_snr = -cal_si_snr(x, xhat)
    >>> print(si_snr)
    tensor([[[ 25.2142, 144.1789],
             [130.9283,  25.2142]]])
    """
    EPS = 1e-8
    # assert source.size() == estimate_source.size()
    # device = estimate_source.device.type

    # source_lengths = torch.tensor(
    #     [estimate_source.shape[0]] * estimate_source.shape[-2], device=device
    # )
    source_lengths = torch.tensor(
        [estimate_source.shape[0]] * estimate_source.shape[-2]
    )
    source = torch.as_tensor(source)
    estimate_source = torch.as_tensor(estimate_source)
    mask = get_mask(source, source_lengths)
    estimate_source *= mask

    num_samples = (
        source_lengths.contiguous().reshape(1, -1, 1).float()
    )  # [1, B, 1]
    mean_target = torch.sum(source, dim=0, keepdim=True) / num_samples
    mean_estimate = (
            torch.sum(estimate_source, dim=0, keepdim=True) / num_samples
    )
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    # mask padding position along T
    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = zero_mean_target  # [T, B, C]
    s_estimate = zero_mean_estimate  # [T, B, C]
    # s_target = <s', s>s / ||s||^2
    dot = torch.sum(s_estimate * s_target, dim=0, keepdim=True)  # [1, B, C]
    s_target_energy = (
            torch.sum(s_target ** 2, dim=0, keepdim=True) + EPS
    )  # [1, B, C]
    proj = dot * s_target / s_target_energy  # [T, B, C]
    # e_noise = s' - s_target
    e_noise = s_estimate - proj  # [T, B, C]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    si_snr_beforelog = torch.sum(proj ** 2, dim=0) / (
            torch.sum(e_noise ** 2, dim=0) + EPS
    )
    si_snr = 10 * torch.log10(si_snr_beforelog + EPS)  # [B, C]

    return -si_snr.unsqueeze(0)


def get_mask(source, source_lengths):
    """
    Arguments
    ---------
    source : [T, B, C]
    source_lengths : [B]
    Returns
    -------
    mask : [T, B, 1]
    Example:
    ---------
    >>> source = torch.randn(4, 3, 2)
    >>> source_lengths = torch.Tensor([2, 1, 4]).int()
    >>> mask = get_mask(source, source_lengths)
    >>> print(mask)
    tensor([[[1.],
             [1.],
             [1.]],
    <BLANKLINE>
            [[1.],
             [0.],
             [1.]],
    <BLANKLINE>
            [[0.],
             [0.],
             [1.]],
    <BLANKLINE>
            [[0.],
             [0.],
             [1.]]])
    """
    mask = source.new_ones(source.size()[:-1]).unsqueeze(-1).transpose(1, -2)
    B = source.size(-2)
    for i in range(B):
        mask[source_lengths[i]:, i] = 0
    return mask.transpose(-2, 1)


def SDR(est, egs, mix):
    '''
        Calculate the SDR improvement over the unprocessed mixture.
        est: network-generated audio
        egs: ground truth
        mix: unprocessed mixture (baseline)
    '''
    length = est.shape[0]
    sdr, _, _, _ = bss_eval_sources(egs[:length], est[:length])
    mix_sdr, _, _, _ = bss_eval_sources(egs[:length], mix[:length])
    return float(sdr - mix_sdr)


# ssnr
def calc_ssnr(signal, noise, frame_size):
    """
    Calculate segmental signal noise ratio.
    If file is not noisy then SNR is about 100dB.
    :param signal: (list)
    :param noise: (list)
    :param frame_size: (int) ssnr frame size
    :return: (value) SSNR(dB)
    """
    if len(signal) != len(noise):
        raise Exception("ERROR: Signal noise size mismatch")
    number_of_frame_size = ceil(len(signal) / frame_size)
    total = 0
    nonzero_frame_number = 0
    segmental_signal_power = [0] * number_of_frame_size
    segmental_noise_power = [0] * number_of_frame_size
    for i in range(number_of_frame_size):
        if i == number_of_frame_size - 1:
            segmental_signal_power[i] = calc_power(signal[frame_size * i:])
            segmental_noise_power[i] = calc_power(noise[frame_size * i:])
        else:
            segmental_signal_power[i] = calc_power(signal[frame_size * i:frame_size * (i + 1)])
            segmental_noise_power[i] = calc_power(noise[frame_size * i:frame_size * (i + 1)])
        if segmental_noise_power[i] == 0:
            segmental_noise_power[i] = 1e-10  # floor to avoid division by zero
        if segmental_signal_power[i] != 0:
            nonzero_frame_number += 1
            total += 10 * log10(segmental_signal_power[i] / segmental_noise_power[i])
    ssnr = total / nonzero_frame_number

    return ssnr
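
# Illustrative synthetic example for calc_ssnr: a 440 Hz tone against white
# noise, per-frame SNR around 10*log10(0.5 / 0.0025) ~ 23 dB. Call
# _demo_calc_ssnr() manually.
def _demo_calc_ssnr():
    t = np.arange(16000) / 16000.0
    signal = np.sin(2 * np.pi * 440.0 * t)           # tone, power ~0.5 per frame
    noise = 0.05 * np.random.randn(16000)            # noise, power ~0.0025
    print(calc_ssnr(signal, noise, frame_size=256))  # ~23 dB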


def calc_power(input):
    """
    Calculate power of input.
    :param input: (list)
    :return: (value)
    """
    total = 0
    for n in input:
        total += pow(n, 2)
    return total / len(input)


# SSNR
def SNRseg(clean_speech, processed_speech, fs, frameLen=0.03, overlap=0.75):
    eps = np.finfo(np.float64).eps

    winlength = round(frameLen * fs)  # window length in samples
    skiprate = int(np.floor((1 - overlap) * frameLen * fs))  # window skip in samples
    MIN_SNR = -10  # minimum SNR in dB
    MAX_SNR = 35  # maximum SNR in dB

    hannWin = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))
    clean_speech_framed = extractOverlappedWindows(clean_speech, winlength, winlength - skiprate, hannWin)
    processed_speech_framed = extractOverlappedWindows(processed_speech, winlength, winlength - skiprate, hannWin)

    signal_energy = np.power(clean_speech_framed, 2).sum(-1)
    noise_energy = np.power(clean_speech_framed - processed_speech_framed, 2).sum(-1)

    segmental_snr = 10 * np.log10(signal_energy / (noise_energy + eps) + eps)
    segmental_snr[segmental_snr < MIN_SNR] = MIN_SNR
    segmental_snr[segmental_snr > MAX_SNR] = MAX_SNR
    segmental_snr = segmental_snr[:-1]  # remove last frame -> not valid
    return np.mean(segmental_snr)
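
# Illustrative example for SNRseg on the same kind of synthetic pair;
# frame-wise SNRs are clipped to [-10, 35] dB before averaging. Call
# _demo_SNRseg() manually.
def _demo_SNRseg():
    fs = 16000
    clean = np.sin(2 * np.pi * 440.0 * np.arange(fs) / fs)
    processed = clean + 0.05 * np.random.randn(fs)
    print(SNRseg(clean, processed, fs))   # ~23 dB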


def extractOverlappedWindows(x, nperseg, noverlap, window=None):
    # source: https://github.com/scipy/scipy/blob/v1.2.1/scipy/signal/spectral.py
    step = nperseg - noverlap
    shape = x.shape[:-1] + ((x.shape[-1] - noverlap) // step, nperseg)
    strides = x.strides[:-1] + (step * x.strides[-1], x.strides[-1])
    result = np.lib.stride_tricks.as_strided(x, shape=shape,
                                             strides=strides)
    if window is not None:
        result = window * result
    return result
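
# Illustrative demo of the strided framing (no window applied); call
# _demo_extractOverlappedWindows() manually.
def _demo_extractOverlappedWindows():
    x = np.arange(10.0)
    frames = extractOverlappedWindows(x, nperseg=4, noverlap=2)
    print(frames)
    # [[0. 1. 2. 3.]
    #  [2. 3. 4. 5.]
    #  [4. 5. 6. 7.]
    #  [6. 7. 8. 9.]]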


if __name__ == '__main__':
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # 指定使用第几个显卡
    args = parser.parse_args()
    print(args)
    evaluate(args)

Reposted from blog.csdn.net/qq_42019881/article/details/127039999