Pytorch 实现 PSD 功率谱计算:periodogram 和 welch 方法

前言

由于跑代码需要,本人用 Pytorch 简单实现了 PSD 估计中的最常用的 periodogramwelch 函数。这里把我的实现分享给大家。

在实现方案和参数的设计上,本人参考了 scipy 的 API 设计。

periodogram 的实现

函数体如下。

def _periodogram(X: torch.Tensor, fs, detrend, scaling):
    if X.dim() > 2:
        X = torch.squeeze(X)
    elif X.dim() == 1:
        X = X.unsqueeze(0)

    if detrend:
        X -= X.mean(-1, keepdim=True)

    N = X.size(-1)
    assert N % 2 == 0

    df = fs / N
    dt = df
    f = torch.arange(0, N / 2 + 1) * df  # [0:df:f/2]

    dual_side = fft.fft(X)  # 双边谱
    half_idx = int(N / 2 + 1)
    single_side = dual_side[:, 0:half_idx]
    win = torch.abs(single_side)

    ps = win ** 2
    if scaling == 'density':  # 计算功率谱密度
        scale = N * fs
    elif scaling == 'spectrum':  # 计算功率谱
        scale = N ** 2
    elif scaling is None: # 不做缩放
        scale = 1
    else:
        raise ValueError('Unknown scaling: %r' % scaling)
    Pxy = ps / scale

    Pxy[:, 1:-1] *= 2  # 能量2倍;直流分量不用二倍, 中心频率点不用二倍

    return f, Pxy.squeeze()


def periodogram(X: torch.Tensor, fs=256, detrend=False, scaling='density', no_grad=True):
    """计算信号单边 PSD, 基本等价于 scipy.signal.periodogram
    
        Parameters:
        ----------
            - `X`:          torch.Tensor, EEG, [T]/[N, T]
            - `fs`:         int, 采样率, Hz
            - `detrend`:    bool, 是否去趋势 (去除直流分量)
            - `scaling`:    { 'density', 'spectrum' }, 可选
                - 'density':    计算功率谱密度 `(V ** 2 / Hz)`
                - 'spectrum':    计算功率谱 `(V ** 2)`
            - `no_grad`:    bool, 是否启用 no_grad() 模式
    
        Returns:
        ----------
            - `Pxy`:    Tensor, 单边功率谱
    """
    if no_grad:
        with torch.no_grad():
            return _periodogram(X, fs, detrend, scaling)
    else:
        return _periodogram(X, fs, detrend, scaling)

welch 的实现

def _get_window(window, nwlen, device):
    if window == 'hann':
        window = torch.hann_window(
            nwlen, dtype=torch.float32, device=device, periodic=False
        )
    elif window == 'hamming':
        window = torch.hamming_window(
            nwlen, dtype=torch.float32, device=device, periodic=False
        )
    elif window == 'blackman':
        window = torch.blackman_window(
            nwlen, dtype=torch.float32, device=device, periodic=False
        )
    elif window == 'boxcar':
        window = torch.ones(nwlen, dtype=torch.float32, device=device)
    else:
        raise ValueError('Invalid Window {}' % window)
    return window


def _pwelch(X: torch.Tensor, fs, detrend, scaling, window, nwlen, nhop):
    X = ensure_3dim(X)
    if scaling == 'density':
        scale = (fs * (window * window).sum().item())
    elif scaling == 'spectrum':
        scale = window.sum().item() ** 2
    else:
        raise ValueError('Unknown scaling: %r' % scaling)
    # --------------- Fold and windowing --------------- #
    N, T = X.size(0), X.size(-1)
    X = X.view(N, 1, 1, T)
    X_fold = F.unfold(X, (1, nwlen), stride=nhop)  # [N, 1, 1, T] -> [N, nwlen, win_cnt]
    if detrend:
        X_fold -= X_fold.mean(1, keepdim=True) # 各个窗口各自detrend
    window = window.view(1, -1, 1)  # [1, nwlen, 1]
    X_windowed = X_fold * window  # [N, nwlen, win_cnt]
    win_cnt = X_windowed.size(-1)

    # --------------- Pwelch --------------- #
    X_windowed = X_windowed.transpose(1, 2).contiguous()  # [N, win_cnt, nwlen]
    X_windowed = X_windowed.view(N * win_cnt, nwlen)  # [N * win_cnt, nwlen]
    f, pxx = _periodogram(
        X_windowed, fs, detrend=False, scaling=None
    )  # [N * win_cnt, nwlen // 2 + 1]
    pxx /= scale
    pxx = pxx.view(N, win_cnt, -1)  # [N, win_cnt, nwlen // 2 + 1]
    pxx = torch.mean(pxx, dim=1)  # [N, nwlen // 2 + 1]
    return f, pxx


def pwelch(
    X: torch.Tensor,
    fs=256,
    detrend=False,
    scaling='density',
    window='hann',
    nwlen=128,
    nhop=None,
    no_grad=True,
):
    """Pwelch 方法,大致相当于 scipy.signal.welch

        Parameters:
        ----------
            - `X`:          torch.Tensor, EEG, [T]/[N, T]
            - `fs`:         int, 采样率, Hz
            - `detrend`:    bool, 是否去趋势 (去除直流分量)
            - `scaling`:    { 'density', 'spectrum' }, 可选
                - 'density':    计算功率谱密度 `(V ** 2 / Hz)`
                - 'spectrum':    计算功率谱 `(V ** 2)`
            - `window`:     str, 窗函数名称
            - `nwlen`:      int, 窗函数长度 (点的个数)
            - `nhop`:       int, 窗函数移动步长, 即 nwlen - noverlap (点的个数)
                            如果为 None,则默认为 `nwlen // 2`
            - `no_grad`:    bool, 是否启用 no_grad() 模式
    
        Returns:
        ----------
            - `Pxy`:    Tensor, 单边功率谱
    """
    nhop = nwlen // 2 if nhop is None else nhop
    window = _get_window(window, nwlen, X.device)
    if no_grad:
        with torch.no_grad():
            return _pwelch(X, fs, detrend, scaling, window, nwlen, nhop)
    else:
        return _pwelch(X, fs, detrend, scaling, window, nwlen, nhop)

测试

导入数据

在这里插入图片描述

周期图法

periodogram

welch 法

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/frostime/article/details/120870904
今日推荐