LoRA in speech synthesis: plug-in speaker development and the future of voice cloning

Paper link:

https://arxiv.org/abs/2211.00585

Paper title:

Adapter-Based Extension of Multi-Speaker Text-to-Speech Model for New Speakers

Publication date:

November 1, 2022

Contributions:

First, we pre-train a base multi-speaker TTS model on a large and diverse TTS dataset. To extend the model to new speakers, we add a few adapters (small modules) to the base model: the vanilla adapter [15], unified adapters [16, 17, 18], or BitFit [19]. Then we freeze the pre-trained model and fine-tune only the adapters on the new speaker's data.
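To make the adapter idea concrete, here is a minimal sketch of the vanilla bottleneck adapter pattern (down-projection, nonlinearity, up-projection, residual connection) in PyTorch. The dimensions, names, and zero-initialization choice are illustrative assumptions, not the paper's exact implementation:

import torch
import torch.nn as nn

class BottleneckAdapter(nn.Module):
    """Vanilla adapter: down-project, nonlinearity, up-project, residual add."""

    def __init__(self, dim: int, bottleneck: int = 32):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.act = nn.ReLU()
        self.up = nn.Linear(bottleneck, dim)
        # Zero-init the up-projection so the adapter starts as the identity,
        # leaving the frozen base model's behavior unchanged at step 0.
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.up(self.act(self.down(x)))

Because the residual path carries the base model's output, only these few thousand adapter parameters need to be stored per speaker.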

• We propose a new adapter-based framework for efficiently tuning a TTS model for new speakers without forgetting previously learned speakers.

• We validate our design through a comprehensive ablation study across different types of adapter modules, amounts of training data, and recording conditions.

• We demonstrate that adapter-based TTS tuning matches the quality of full-model fine-tuning while requiring significantly less compute and data.

Core architecture:

The proposed pipeline for adapting the multi-speaker TTS model to new speakers:

(a) Pre-train a multi-speaker FastPitch model.

(b) Freeze the weights of the pre-trained FastPitch model and add adapter modules.

(c) Only the adapters are fine-tuned for the new speaker.

(d) At inference, all speakers share the same base model, and the lightweight, speaker-specific adapter weights are plugged in: the "plug-in speaker".
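A minimal sketch of steps (b) and (c), assuming a PyTorch model whose adapter parameters carry "adapter" in their names (a naming convention assumed here for illustration) and a hypothetical compute_loss helper:

import torch

def finetune_adapters_only(model, dataloader, epochs: int = 10):
    # (b) Freeze everything, then re-enable gradients for adapter parameters only.
    for name, param in model.named_parameters():
        param.requires_grad = "adapter" in name

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=1e-4)

    # (c) Standard training loop over the new speaker's recordings.
    for _ in range(epochs):
        for batch in dataloader:
            loss = model.compute_loss(batch)  # hypothetical loss API
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

After training, only the adapter weights (a small fraction of the full model) are saved per speaker and plugged in at inference, as in step (d).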

Architecture of the proposed multi-speaker FastPitch. It is composed of a phoneme encoder, mel decoder, duration and pitch predictors, aligner, and speaker encoder. Speaker identity is controlled by conditional layer normalization (CLN) and by concatenating inputs with the speaker representation.
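Conditional layer normalization replaces LayerNorm's fixed scale and bias with values predicted from the speaker embedding. A minimal sketch, with shapes and names as illustrative assumptions rather than NeMo's exact code:

import torch
import torch.nn as nn

class ConditionalLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are predicted from a speaker embedding."""

    def __init__(self, hidden_dim: int, speaker_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.scale = nn.Linear(speaker_dim, hidden_dim)
        self.shift = nn.Linear(speaker_dim, hidden_dim)

    def forward(self, x: torch.Tensor, spk_emb: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, hidden_dim); spk_emb: (batch, speaker_dim)
        gamma = self.scale(spk_emb).unsqueeze(1)  # (batch, 1, hidden_dim)
        beta = self.shift(spk_emb).unsqueeze(1)
        return gamma * self.norm(x) + beta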

Illustration of parameter-efficient tuning modules in the transformer architecture. LoRA and prefix tuning are used only in the FFT (feed-forward transformer) blocks, while the adapter and BitFit can be applied to any component of FastPitch.
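Since the title highlights LoRA, here is a minimal sketch of a LoRA-wrapped linear layer: the pre-trained weight is frozen and a trainable low-rank update is added on top. Rank, scaling, and initialization are illustrative assumptions:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen pre-trained linear layer plus a trainable low-rank (LoRA) update."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze the pre-trained weight and bias
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))  # zero-init: update starts at zero
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen path plus scaled low-rank update x @ A^T @ B^T.
        return self.base(x) + self.scaling * (x @ self.lora_a.T @ self.lora_b.T)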

Code implementation (excerpt):

https://github.com/NVIDIA/NeMo/pull/6416

Adds FastPitch pre-training with CLNs and fine-tuning with adapters.

Changelog

• Adds multi-speaker FastPitch pre-training with Conditional Layer Normalization
  • nemo/collections/tts/modules/fastpitch.py
  • nemo/collections/tts/modules/transformer.py
  • nemo/collections/tts/modules/submodules.py
• Adds adapter modules for FastPitch fine-tuning
  • nemo/collections/tts/models/fastpitch.py
  • nemo/collections/tts/modules/fastpitch.py
  • nemo/collections/tts/modules/transformer.py
  • nemo/collections/tts/modules/aligner.py
  • nemo/collections/tts/parts/mixins/__init__.py
  • nemo/collections/tts/parts/mixins/fastpitch_adapter_mixins.py
• Adds config and fine-tuning Python script
  • examples/tts/conf/fastpitch_speaker_adaptation.yaml
  • examples/tts/fastpitch_finetune_adapters.py
• Fixes aligner NaN loss bug
  • nemo/collections/tts/losses/aligner_loss.py

nemo/collections/tts/modules/adapters.py

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from omegaconf import DictConfig
from nemo.collections.asr.parts.utils import adapter_utils
from nemo.collections.tts.modules.aligner import AlignmentEncoder
from nemo.collections.tts.modules.fastpitch import TemporalPredictor
from nemo.collections.tts.modules.transformer import FFTransformerDecoder, FFTransformerEncoder
from nemo.core.classes import adapter_mixins

class FFTransformerDecoderAdapter(FFTransformerDecoder, adapter_mixins.AdapterModuleMixin):
    """Inherit from FFTransformerDecoder and add support for adapters."""

    def add_adapter(self, name: str, cfg: dict):
        cfg = self._update_adapter_cfg_input_dim(cfg)
        for fft_layer in self.layers:  # type: adapter_mixins.AdapterModuleMixin
            fft_layer.add_adapter(name, cfg)

    def is_adapter_available(self) -> bool:
        return any([FFT_layer.is_adapter_available() for FFT_layer in self.layers])

    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
        for FFT_layer in self.layers:  # type: adapter_mixins.AdapterModuleMixin
            FFT_layer.set_enabled_adapters(name=name, enabled=enabled)

    def get_enabled_adapters(self) -> List[str]:
        names = set([])
        for FFT_layer in self.layers:  # type: adapter_mixins.AdapterModuleMixin
            names.update(FFT_layer.get_enabled_adapters())

        names = sorted(list(names))
        return names

    def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
        cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model)
        return cfg

class FFTransformerEncoderAdapter(
    FFTransformerDecoderAdapter, FFTransformerEncoder, adapter_mixins.AdapterModuleMixin
):
    """Inherit from FFTransformerEncoder and add support for adapters."""

    pass

class AlignmentEncoderAdapter(AlignmentEncoder, adapter_mixins.AdapterModuleMixin):
    """Inherit from AlignmentEncoder and add support for adapters."""

    def add_adapter(self, name: str, cfg: dict):
        for i, conv_layer in enumerate(self.key_proj):
            if i % 2 == 0:
                cfg = self._update_adapter_cfg_input_dim(cfg, conv_layer.conv.out_channels)
                conv_layer.add_adapter(name, cfg)

        for i, conv_layer in enumerate(self.query_proj):
            if i % 2 == 0:
                cfg = self._update_adapter_cfg_input_dim(cfg, conv_layer.conv.out_channels)
                conv_layer.add_adapter(name, cfg)

    def is_adapter_available(self) -> bool:
        return any(
            [conv_layer.is_adapter_available() for i, conv_layer in enumerate(self.key_proj) if i % 2 == 0]
            + [conv_layer.is_adapter_available() for i, conv_layer in enumerate(self.query_proj) if i % 2 == 0]
        )

    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
        for i, conv_layer in enumerate(self.key_proj):
            if i % 2 == 0:
                conv_layer.set_enabled_adapters(name=name, enabled=enabled)
        for i, conv_layer in enumerate(self.query_proj):
            if i % 2 == 0:
                conv_layer.set_enabled_adapters(name=name, enabled=enabled)

    def get_enabled_adapters(self) -> List[str]:
        names = set([])
        for i, conv_layer in enumerate(self.key_proj):
            if i % 2 == 0:
                names.update(conv_layer.get_enabled_adapters())
        for i, conv_layer in enumerate(self.query_proj):
            if i % 2 == 0:
                names.update(conv_layer.get_enabled_adapters())

        names = sorted(list(names))
        return names

    def _update_adapter_cfg_input_dim(self, cfg: DictConfig, module_dim: int):
        cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=module_dim)
        return cfg

class TemporalPredictorAdapter(TemporalPredictor, adapter_mixins.AdapterModuleMixin):
    """Inherit from TemporalPredictor and add support for adapters."""

    def add_adapter(self, name: str, cfg: dict):
        cfg = self._update_adapter_cfg_input_dim(cfg)
        for conv_layer in self.layers:  # type: adapter_mixins.AdapterModuleMixin
            conv_layer.add_adapter(name, cfg)

    def is_adapter_available(self) -> bool:
        return any([conv_layer.is_adapter_available() for conv_layer in self.layers])

    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
        for conv_layer in self.layers:  # type: adapter_mixins.AdapterModuleMixin
            conv_layer.set_enabled_adapters(name=name, enabled=enabled)

    def get_enabled_adapters(self) -> List[str]:
        names = set([])
        for conv_layer in self.layers:  # type: adapter_mixins.AdapterModuleMixin
            names.update(conv_layer.get_enabled_adapters())

        names = sorted(list(names))
        return names

    def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
        cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.filter_size)
        return cfg

"""Register any additional information"""if adapter_mixins.get_registered_adapter(FFTransformerEncoder) is None:    adapter_mixins.register_adapter(base_class=FFTransformerEncoder, adapter_class=FFTransformerEncoderAdapter
if adapter_mixins.get_registered_adapter(FFTransformerDecoder) is None:    adapter_mixins.register_adapter(base_class=FFTransformerDecoder, adapter_class=FFTransformerDecoderAdapter)
if adapter_mixins.get_registered_adapter(AlignmentEncoder) is None:    adapter_mixins.register_adapter(base_class=AlignmentEncoder, adapter_class=AlignmentEncoderAdapter)
if adapter_mixins.get_registered_adapter(TemporalPredictor) is None:    adapter_mixins.register_adapter(base_class=TemporalPredictor, adapter_class=TemporalPredictorAdapter)
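With these registrations in place, NeMo can swap each base module for its adapter-capable counterpart and attach speaker adapters by name. A hedged sketch of the intended usage follows; the real flow lives in examples/tts/fastpitch_finetune_adapters.py, and the checkpoint path and adapter config below are placeholders for illustration:

from nemo.collections.tts.models import FastPitchModel

adapter_cfg = ...  # in practice, loaded from examples/tts/conf/fastpitch_speaker_adaptation.yaml
model = FastPitchModel.restore_from("multispeaker_fastpitch.nemo")  # hypothetical checkpoint path
model.add_adapter(name="new_speaker", cfg=adapter_cfg)  # insert adapter modules by name
model.freeze()                     # freeze all pre-trained weights
model.unfreeze_enabled_adapters()  # leave only the new adapter trainable

Saving just the adapter weights per speaker is what makes the "plug-in speaker" idea practical: one shared base model, plus a small module per voice.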
