当前Convention未适配3.10([T]泛型注解导致的问题)
This commit is contained in:
0
cosyvoice/utils/__init__.py
Normal file
0
cosyvoice/utils/__init__.py
Normal file
83
cosyvoice/utils/class_utils.py
Normal file
83
cosyvoice/utils/class_utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from cosyvoice.transformer.activation import Swish
|
||||
from cosyvoice.transformer.subsampling import (
|
||||
LinearNoSubsampling,
|
||||
EmbedinigNoSubsampling,
|
||||
Conv1dSubsampling2,
|
||||
Conv2dSubsampling4,
|
||||
Conv2dSubsampling6,
|
||||
Conv2dSubsampling8,
|
||||
)
|
||||
from cosyvoice.transformer.embedding import (PositionalEncoding,
|
||||
RelPositionalEncoding,
|
||||
WhisperPositionalEncoding,
|
||||
LearnablePositionalEncoding,
|
||||
NoPositionalEncoding)
|
||||
from cosyvoice.transformer.attention import (MultiHeadedAttention,
|
||||
RelPositionMultiHeadedAttention)
|
||||
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
|
||||
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
|
||||
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
|
||||
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
|
||||
from cosyvoice.hifigan.generator import HiFTGenerator
|
||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
||||
|
||||
|
||||
COSYVOICE_ACTIVATION_CLASSES = {
|
||||
"hardtanh": torch.nn.Hardtanh,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"relu": torch.nn.ReLU,
|
||||
"selu": torch.nn.SELU,
|
||||
"swish": getattr(torch.nn, "SiLU", Swish),
|
||||
"gelu": torch.nn.GELU,
|
||||
}
|
||||
|
||||
COSYVOICE_SUBSAMPLE_CLASSES = {
|
||||
"linear": LinearNoSubsampling,
|
||||
"linear_legacy": LegacyLinearNoSubsampling,
|
||||
"embed": EmbedinigNoSubsampling,
|
||||
"conv1d2": Conv1dSubsampling2,
|
||||
"conv2d": Conv2dSubsampling4,
|
||||
"conv2d6": Conv2dSubsampling6,
|
||||
"conv2d8": Conv2dSubsampling8,
|
||||
'paraformer_dummy': torch.nn.Identity
|
||||
}
|
||||
|
||||
COSYVOICE_EMB_CLASSES = {
|
||||
"embed": PositionalEncoding,
|
||||
"abs_pos": PositionalEncoding,
|
||||
"rel_pos": RelPositionalEncoding,
|
||||
"rel_pos_espnet": EspnetRelPositionalEncoding,
|
||||
"no_pos": NoPositionalEncoding,
|
||||
"abs_pos_whisper": WhisperPositionalEncoding,
|
||||
"embed_learnable_pe": LearnablePositionalEncoding,
|
||||
}
|
||||
|
||||
COSYVOICE_ATTENTION_CLASSES = {
|
||||
"selfattn": MultiHeadedAttention,
|
||||
"rel_selfattn": RelPositionMultiHeadedAttention,
|
||||
}
|
||||
|
||||
|
||||
def get_model_type(configs):
|
||||
# NOTE CosyVoice2Model inherits CosyVoiceModel
|
||||
if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
||||
return CosyVoiceModel
|
||||
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
||||
return CosyVoice2Model
|
||||
raise TypeError('No valid model type found!')
|
||||
186
cosyvoice/utils/common.py
Normal file
186
cosyvoice/utils/common.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""Unility functions for Transformer."""
|
||||
|
||||
import queue
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
IGNORE_ID = -1
|
||||
|
||||
|
||||
def pad_list(xs: List[torch.Tensor], pad_value: int):
|
||||
"""Perform padding for the list of tensors.
|
||||
|
||||
Args:
|
||||
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
||||
pad_value (float): Value for padding.
|
||||
|
||||
Returns:
|
||||
Tensor: Padded tensor (B, Tmax, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
||||
>>> x
|
||||
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
||||
>>> pad_list(x, 0)
|
||||
tensor([[1., 1., 1., 1.],
|
||||
[1., 1., 0., 0.],
|
||||
[1., 0., 0., 0.]])
|
||||
|
||||
"""
|
||||
max_len = max([len(item) for item in xs])
|
||||
batchs = len(xs)
|
||||
ndim = xs[0].ndim
|
||||
if ndim == 1:
|
||||
pad_res = torch.zeros(batchs,
|
||||
max_len,
|
||||
dtype=xs[0].dtype,
|
||||
device=xs[0].device)
|
||||
elif ndim == 2:
|
||||
pad_res = torch.zeros(batchs,
|
||||
max_len,
|
||||
xs[0].shape[1],
|
||||
dtype=xs[0].dtype,
|
||||
device=xs[0].device)
|
||||
elif ndim == 3:
|
||||
pad_res = torch.zeros(batchs,
|
||||
max_len,
|
||||
xs[0].shape[1],
|
||||
xs[0].shape[2],
|
||||
dtype=xs[0].dtype,
|
||||
device=xs[0].device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported ndim: {ndim}")
|
||||
pad_res.fill_(pad_value)
|
||||
for i in range(batchs):
|
||||
pad_res[i, :len(xs[i])] = xs[i]
|
||||
return pad_res
|
||||
|
||||
|
||||
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
|
||||
ignore_label: int) -> torch.Tensor:
|
||||
"""Calculate accuracy.
|
||||
|
||||
Args:
|
||||
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax).
|
||||
ignore_label (int): Ignore label id.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Accuracy value (0.0 - 1.0).
|
||||
|
||||
"""
|
||||
pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
|
||||
pad_outputs.size(1)).argmax(2)
|
||||
mask = pad_targets != ignore_label
|
||||
numerator = torch.sum(
|
||||
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
|
||||
denominator = torch.sum(mask)
|
||||
return (numerator / denominator).detach()
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
# Repetition Aware Sampling in VALL-E 2
|
||||
def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
|
||||
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
|
||||
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
|
||||
if rep_num >= win_size * tau_r:
|
||||
top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
|
||||
return top_ids
|
||||
|
||||
|
||||
def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
|
||||
prob, indices = [], []
|
||||
cum_prob = 0.0
|
||||
sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
|
||||
for i in range(len(sorted_idx)):
|
||||
# sampling both top-p and numbers.
|
||||
if cum_prob < top_p and len(prob) < top_k:
|
||||
cum_prob += sorted_value[i]
|
||||
prob.append(sorted_value[i])
|
||||
indices.append(sorted_idx[i])
|
||||
else:
|
||||
break
|
||||
prob = torch.tensor(prob).to(weighted_scores)
|
||||
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
|
||||
top_ids = indices[prob.multinomial(1, replacement=True)]
|
||||
return top_ids
|
||||
|
||||
|
||||
def random_sampling(weighted_scores, decoded_tokens, sampling):
|
||||
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
|
||||
return top_ids
|
||||
|
||||
|
||||
def fade_in_out(fade_in_mel, fade_out_mel, window):
|
||||
device = fade_in_mel.device
|
||||
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
|
||||
mel_overlap_len = int(window.shape[0] / 2)
|
||||
if fade_in_mel.device == torch.device('cpu'):
|
||||
fade_in_mel = fade_in_mel.clone()
|
||||
fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
|
||||
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
|
||||
return fade_in_mel.to(device)
|
||||
|
||||
|
||||
def set_all_random_seed(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
assert mask.dtype == torch.bool
|
||||
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
||||
mask = mask.to(dtype)
|
||||
# attention mask bias
|
||||
# NOTE(Mddct): torch.finfo jit issues
|
||||
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
||||
mask = (1.0 - mask) * -1.0e+10
|
||||
return mask
|
||||
|
||||
|
||||
class TrtContextWrapper:
|
||||
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
|
||||
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||
self.trt_engine = trt_engine
|
||||
for _ in range(trt_concurrent):
|
||||
trt_context = trt_engine.create_execution_context()
|
||||
trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
|
||||
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
|
||||
self.trt_context_pool.put([trt_context, trt_stream])
|
||||
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
|
||||
|
||||
def acquire_estimator(self):
|
||||
return self.trt_context_pool.get(), self.trt_engine
|
||||
|
||||
def release_estimator(self, context, stream):
|
||||
self.trt_context_pool.put([context, stream])
|
||||
176
cosyvoice/utils/executor.py
Normal file
176
cosyvoice/utils/executor.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
|
||||
|
||||
|
||||
class Executor:
|
||||
|
||||
def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
|
||||
self.gan = gan
|
||||
self.ref_model = ref_model
|
||||
self.dpo_loss = dpo_loss
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.rank = int(os.environ.get('RANK', 0))
|
||||
self.device = torch.device('cuda:{}'.format(self.rank))
|
||||
|
||||
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
|
||||
''' Train one epoch
|
||||
'''
|
||||
|
||||
lr = optimizer.param_groups[0]['lr']
|
||||
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
|
||||
logging.info('using accumulate grad, new batch size is {} times'
|
||||
' larger than before'.format(info_dict['accum_grad']))
|
||||
# A context manager to be used in conjunction with an instance of
|
||||
# torch.nn.parallel.DistributedDataParallel to be able to train
|
||||
# with uneven inputs across participating processes.
|
||||
model.train()
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
||||
with model_context():
|
||||
for batch_idx, batch_dict in enumerate(train_data_loader):
|
||||
info_dict["tag"] = "TRAIN"
|
||||
info_dict["step"] = self.step
|
||||
info_dict["epoch"] = self.epoch
|
||||
info_dict["batch_idx"] = batch_idx
|
||||
if cosyvoice_join(group_join, info_dict):
|
||||
break
|
||||
|
||||
# Disable gradient synchronizations across DDP processes.
|
||||
# Within this context, gradients will be accumulated on module
|
||||
# variables, which will later be synchronized.
|
||||
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
||||
context = model.no_sync
|
||||
# Used for single gpu training and DDP gradient synchronization
|
||||
# processes.
|
||||
else:
|
||||
context = nullcontext
|
||||
|
||||
with context():
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
|
||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
||||
log_per_step(writer, info_dict)
|
||||
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
||||
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
||||
(batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
||||
model.train()
|
||||
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||
self.step += 1
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
||||
|
||||
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||
writer, info_dict, scaler, group_join):
|
||||
''' Train one epoch
|
||||
'''
|
||||
|
||||
lr = optimizer.param_groups[0]['lr']
|
||||
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
|
||||
logging.info('using accumulate grad, new batch size is {} times'
|
||||
' larger than before'.format(info_dict['accum_grad']))
|
||||
# A context manager to be used in conjunction with an instance of
|
||||
# torch.nn.parallel.DistributedDataParallel to be able to train
|
||||
# with uneven inputs across participating processes.
|
||||
model.train()
|
||||
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
||||
with model_context():
|
||||
for batch_idx, batch_dict in enumerate(train_data_loader):
|
||||
info_dict["tag"] = "TRAIN"
|
||||
info_dict["step"] = self.step
|
||||
info_dict["epoch"] = self.epoch
|
||||
info_dict["batch_idx"] = batch_idx
|
||||
if cosyvoice_join(group_join, info_dict):
|
||||
break
|
||||
|
||||
# Disable gradient synchronizations across DDP processes.
|
||||
# Within this context, gradients will be accumulated on module
|
||||
# variables, which will later be synchronized.
|
||||
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
||||
context = model.no_sync
|
||||
# Used for single gpu training and DDP gradient synchronization
|
||||
# processes.
|
||||
else:
|
||||
context = nullcontext
|
||||
|
||||
with context():
|
||||
batch_dict['turn'] = 'discriminator'
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
|
||||
optimizer.zero_grad()
|
||||
log_per_step(writer, info_dict)
|
||||
with context():
|
||||
batch_dict['turn'] = 'generator'
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
||||
optimizer_d.zero_grad()
|
||||
log_per_step(writer, info_dict)
|
||||
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
||||
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
||||
(batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
||||
model.train()
|
||||
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||
self.step += 1
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
||||
|
||||
@torch.inference_mode()
|
||||
def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
|
||||
''' Cross validation on
|
||||
'''
|
||||
logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
|
||||
model.eval()
|
||||
total_num_utts, total_loss_dict = 0, {} # avoid division by 0
|
||||
for batch_idx, batch_dict in enumerate(cv_data_loader):
|
||||
info_dict["tag"] = "CV"
|
||||
info_dict["step"] = self.step
|
||||
info_dict["epoch"] = self.epoch
|
||||
info_dict["batch_idx"] = batch_idx
|
||||
|
||||
num_utts = len(batch_dict["utts"])
|
||||
total_num_utts += num_utts
|
||||
|
||||
if self.gan is True:
|
||||
batch_dict['turn'] = 'generator'
|
||||
info_dict = batch_forward(model, batch_dict, None, info_dict)
|
||||
|
||||
for k, v in info_dict['loss_dict'].items():
|
||||
if k not in total_loss_dict:
|
||||
total_loss_dict[k] = []
|
||||
total_loss_dict[k].append(v.mean().item() * num_utts)
|
||||
log_per_step(None, info_dict)
|
||||
for k, v in total_loss_dict.items():
|
||||
total_loss_dict[k] = sum(v) / total_num_utts
|
||||
info_dict['loss_dict'] = total_loss_dict
|
||||
log_per_save(writer, info_dict)
|
||||
model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
|
||||
save_model(model, model_name, info_dict)
|
||||
129
cosyvoice/utils/file_utils.py
Normal file
129
cosyvoice/utils/file_utils.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import torchaudio
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
|
||||
|
||||
def read_lists(list_file):
|
||||
lists = []
|
||||
with open(list_file, 'r', encoding='utf8') as fin:
|
||||
for line in fin:
|
||||
lists.append(line.strip())
|
||||
return lists
|
||||
|
||||
|
||||
def read_json_lists(list_file):
|
||||
lists = read_lists(list_file)
|
||||
results = {}
|
||||
for fn in lists:
|
||||
with open(fn, 'r', encoding='utf8') as fin:
|
||||
results.update(json.load(fin))
|
||||
return results
|
||||
|
||||
|
||||
def load_wav(wav, target_sr):
|
||||
speech, sample_rate = torchaudio.load(wav, backend='soundfile')
|
||||
speech = speech.mean(dim=0, keepdim=True)
|
||||
if sample_rate != target_sr:
|
||||
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
||||
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
||||
return speech
|
||||
|
||||
|
||||
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
||||
import tensorrt as trt
|
||||
logging.info("Converting onnx to trt...")
|
||||
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
builder = trt.Builder(logger)
|
||||
network = builder.create_network(network_flags)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
config = builder.create_builder_config()
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
|
||||
if fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
profile = builder.create_optimization_profile()
|
||||
# load onnx model
|
||||
with open(onnx_model, "rb") as f:
|
||||
if not parser.parse(f.read()):
|
||||
for error in range(parser.num_errors):
|
||||
print(parser.get_error(error))
|
||||
raise ValueError('failed to parse {}'.format(onnx_model))
|
||||
# set input shapes
|
||||
for i in range(len(trt_kwargs['input_names'])):
|
||||
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
|
||||
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
|
||||
# set input and output data type
|
||||
for i in range(network.num_inputs):
|
||||
input_tensor = network.get_input(i)
|
||||
input_tensor.dtype = tensor_dtype
|
||||
for i in range(network.num_outputs):
|
||||
output_tensor = network.get_output(i)
|
||||
output_tensor.dtype = tensor_dtype
|
||||
config.add_optimization_profile(profile)
|
||||
engine_bytes = builder.build_serialized_network(network, config)
|
||||
# save trt engine
|
||||
with open(trt_model, "wb") as f:
|
||||
f.write(engine_bytes)
|
||||
logging.info("Succesfully convert onnx to trt...")
|
||||
|
||||
|
||||
def export_cosyvoice2_vllm(model, model_path, device):
|
||||
if os.path.exists(model_path):
|
||||
return
|
||||
pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
vocab_size = model.speech_embedding.num_embeddings
|
||||
feature_size = model.speech_embedding.embedding_dim
|
||||
pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
||||
|
||||
dtype = torch.bfloat16
|
||||
# lm_head
|
||||
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
|
||||
with torch.no_grad():
|
||||
new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
|
||||
new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
|
||||
new_lm_head.weight[vocab_size:] = 0
|
||||
new_lm_head.bias[vocab_size:] = 0
|
||||
model.llm.model.lm_head = new_lm_head
|
||||
new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
|
||||
# embed_tokens
|
||||
embed_tokens = model.llm.model.model.embed_tokens
|
||||
with torch.no_grad():
|
||||
new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
|
||||
new_codec_embed.weight[vocab_size:] = 0
|
||||
model.llm.model.set_input_embeddings(new_codec_embed)
|
||||
model.llm.model.to(device)
|
||||
model.llm.model.to(dtype)
|
||||
tmp_vocab_size = model.llm.model.config.vocab_size
|
||||
tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
|
||||
del model.llm.model.generation_config.eos_token_id
|
||||
del model.llm.model.config.bos_token_id
|
||||
del model.llm.model.config.eos_token_id
|
||||
model.llm.model.config.vocab_size = pad_vocab_size
|
||||
model.llm.model.config.tie_word_embeddings = False
|
||||
model.llm.model.config.use_bias = True
|
||||
model.llm.model.save_pretrained(model_path)
|
||||
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
|
||||
model.llm.model.config.vocab_size = tmp_vocab_size
|
||||
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
|
||||
model.llm.model.set_input_embeddings(embed_tokens)
|
||||
136
cosyvoice/utils/frontend_utils.py
Normal file
136
cosyvoice/utils/frontend_utils.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import regex
|
||||
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
||||
|
||||
|
||||
# whether contain chinese character
|
||||
def contains_chinese(text):
|
||||
return bool(chinese_char_pattern.search(text))
|
||||
|
||||
|
||||
# replace special symbol
|
||||
def replace_corner_mark(text):
|
||||
text = text.replace('²', '平方')
|
||||
text = text.replace('³', '立方')
|
||||
return text
|
||||
|
||||
|
||||
# remove meaningless symbol
|
||||
def remove_bracket(text):
|
||||
text = text.replace('(', '').replace(')', '')
|
||||
text = text.replace('【', '').replace('】', '')
|
||||
text = text.replace('`', '').replace('`', '')
|
||||
text = text.replace("——", " ")
|
||||
return text
|
||||
|
||||
|
||||
# spell Arabic numerals
|
||||
def spell_out_number(text: str, inflect_parser):
|
||||
new_text = []
|
||||
st = None
|
||||
for i, c in enumerate(text):
|
||||
if not c.isdigit():
|
||||
if st is not None:
|
||||
num_str = inflect_parser.number_to_words(text[st: i])
|
||||
new_text.append(num_str)
|
||||
st = None
|
||||
new_text.append(c)
|
||||
else:
|
||||
if st is None:
|
||||
st = i
|
||||
if st is not None and st < len(text):
|
||||
num_str = inflect_parser.number_to_words(text[st:])
|
||||
new_text.append(num_str)
|
||||
return ''.join(new_text)
|
||||
|
||||
|
||||
# split paragrah logic:
|
||||
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
|
||||
# 2. cal sentence len according to lang
|
||||
# 3. split sentence according to puncatation
|
||||
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
|
||||
def calc_utt_length(_text: str):
|
||||
if lang == "zh":
|
||||
return len(_text)
|
||||
else:
|
||||
return len(tokenize(_text))
|
||||
|
||||
def should_merge(_text: str):
|
||||
if lang == "zh":
|
||||
return len(_text) < merge_len
|
||||
else:
|
||||
return len(tokenize(_text)) < merge_len
|
||||
|
||||
if lang == "zh":
|
||||
pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
|
||||
else:
|
||||
pounc = ['.', '?', '!', ';', ':']
|
||||
if comma_split:
|
||||
pounc.extend([',', ','])
|
||||
|
||||
if text[-1] not in pounc:
|
||||
if lang == "zh":
|
||||
text += "。"
|
||||
else:
|
||||
text += "."
|
||||
|
||||
st = 0
|
||||
utts = []
|
||||
for i, c in enumerate(text):
|
||||
if c in pounc:
|
||||
if len(text[st: i]) > 0:
|
||||
utts.append(text[st: i] + c)
|
||||
if i + 1 < len(text) and text[i + 1] in ['"', '”']:
|
||||
tmp = utts.pop(-1)
|
||||
utts.append(tmp + text[i + 1])
|
||||
st = i + 2
|
||||
else:
|
||||
st = i + 1
|
||||
|
||||
final_utts = []
|
||||
cur_utt = ""
|
||||
for utt in utts:
|
||||
if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
|
||||
final_utts.append(cur_utt)
|
||||
cur_utt = ""
|
||||
cur_utt = cur_utt + utt
|
||||
if len(cur_utt) > 0:
|
||||
if should_merge(cur_utt) and len(final_utts) != 0:
|
||||
final_utts[-1] = final_utts[-1] + cur_utt
|
||||
else:
|
||||
final_utts.append(cur_utt)
|
||||
|
||||
return final_utts
|
||||
|
||||
|
||||
# remove blank between chinese character
|
||||
def replace_blank(text: str):
|
||||
out_str = []
|
||||
for i, c in enumerate(text):
|
||||
if c == " ":
|
||||
if ((text[i + 1].isascii() and text[i + 1] != " ") and
|
||||
(text[i - 1].isascii() and text[i - 1] != " ")):
|
||||
out_str.append(c)
|
||||
else:
|
||||
out_str.append(c)
|
||||
return "".join(out_str)
|
||||
|
||||
|
||||
def is_only_punctuation(text):
|
||||
# Regular expression: Match strings that consist only of punctuation marks or are empty.
|
||||
punctuation_pattern = r'^[\p{P}\p{S}]*$'
|
||||
return bool(regex.fullmatch(punctuation_pattern, text))
|
||||
57
cosyvoice/utils/losses.py
Normal file
57
cosyvoice/utils/losses.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
|
||||
loss = 0
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
m_DG = torch.median((dr - dg))
|
||||
L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
|
||||
loss += tau - F.relu(tau - L_rel)
|
||||
return loss
|
||||
|
||||
|
||||
def mel_loss(real_speech, generated_speech, mel_transforms):
|
||||
loss = 0
|
||||
for transform in mel_transforms:
|
||||
mel_r = transform(real_speech)
|
||||
mel_g = transform(generated_speech)
|
||||
loss += F.l1_loss(mel_g, mel_r)
|
||||
return loss
|
||||
|
||||
|
||||
class DPOLoss(torch.nn.Module):
|
||||
"""
|
||||
DPO Loss
|
||||
"""
|
||||
|
||||
def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.label_smoothing = label_smoothing
|
||||
self.ipo = ipo
|
||||
|
||||
def forward(
|
||||
self,
|
||||
policy_chosen_logps: torch.Tensor,
|
||||
policy_rejected_logps: torch.Tensor,
|
||||
reference_chosen_logps: torch.Tensor,
|
||||
reference_rejected_logps: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
||||
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
||||
logits = pi_logratios - ref_logratios
|
||||
if self.ipo:
|
||||
losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
|
||||
else:
|
||||
# Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
|
||||
losses = (
|
||||
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
||||
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
||||
)
|
||||
loss = losses.mean()
|
||||
chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
|
||||
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
|
||||
|
||||
return loss, chosen_rewards, rejected_rewards
|
||||
265
cosyvoice/utils/mask.py
Normal file
265
cosyvoice/utils/mask.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# Copyright (c) 2019 Shigeki Karita
|
||||
# 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
'''
|
||||
def subsequent_mask(
|
||||
size: int,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size).
|
||||
|
||||
This mask is used only in decoder which works in an auto-regressive mode.
|
||||
This means the current step could only do attention with its left steps.
|
||||
|
||||
In encoder, fully attention is used when streaming is not necessary and
|
||||
the sequence is not long. In this case, no attention mask is needed.
|
||||
|
||||
When streaming is need, chunk-based attention is used in encoder. See
|
||||
subsequent_chunk_mask for the chunk-based attention mask.
|
||||
|
||||
Args:
|
||||
size (int): size of mask
|
||||
str device (str): "cpu" or "cuda" or torch.Tensor.device
|
||||
dtype (torch.device): result dtype
|
||||
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
|
||||
Examples:
|
||||
>>> subsequent_mask(3)
|
||||
[[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1]]
|
||||
"""
|
||||
ret = torch.ones(size, size, device=device, dtype=torch.bool)
|
||||
return torch.tril(ret)
|
||||
'''
|
||||
|
||||
|
||||
def subsequent_mask(
|
||||
size: int,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size).
|
||||
|
||||
This mask is used only in decoder which works in an auto-regressive mode.
|
||||
This means the current step could only do attention with its left steps.
|
||||
|
||||
In encoder, fully attention is used when streaming is not necessary and
|
||||
the sequence is not long. In this case, no attention mask is needed.
|
||||
|
||||
When streaming is need, chunk-based attention is used in encoder. See
|
||||
subsequent_chunk_mask for the chunk-based attention mask.
|
||||
|
||||
Args:
|
||||
size (int): size of mask
|
||||
str device (str): "cpu" or "cuda" or torch.Tensor.device
|
||||
dtype (torch.device): result dtype
|
||||
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
|
||||
Examples:
|
||||
>>> subsequent_mask(3)
|
||||
[[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1]]
|
||||
"""
|
||||
arange = torch.arange(size, device=device)
|
||||
mask = arange.expand(size, size)
|
||||
arange = arange.unsqueeze(-1)
|
||||
mask = mask <= arange
|
||||
return mask
|
||||
|
||||
|
||||
def subsequent_chunk_mask_deprecated(
|
||||
size: int,
|
||||
chunk_size: int,
|
||||
num_left_chunks: int = -1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||
this is for streaming encoder
|
||||
|
||||
Args:
|
||||
size (int): size of mask
|
||||
chunk_size (int): size of chunk
|
||||
num_left_chunks (int): number of left chunks
|
||||
<0: use full chunk
|
||||
>=0: use num_left_chunks
|
||||
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
|
||||
Examples:
|
||||
>>> subsequent_chunk_mask(4, 2)
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]]
|
||||
"""
|
||||
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
||||
for i in range(size):
|
||||
if num_left_chunks < 0:
|
||||
start = 0
|
||||
else:
|
||||
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
|
||||
ending = min((i // chunk_size + 1) * chunk_size, size)
|
||||
ret[i, start:ending] = True
|
||||
return ret
|
||||
|
||||
|
||||
def subsequent_chunk_mask(
|
||||
size: int,
|
||||
chunk_size: int,
|
||||
num_left_chunks: int = -1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||
this is for streaming encoder
|
||||
|
||||
Args:
|
||||
size (int): size of mask
|
||||
chunk_size (int): size of chunk
|
||||
num_left_chunks (int): number of left chunks
|
||||
<0: use full chunk
|
||||
>=0: use num_left_chunks
|
||||
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
|
||||
Examples:
|
||||
>>> subsequent_chunk_mask(4, 2)
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]]
|
||||
"""
|
||||
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
||||
pos_idx = torch.arange(size, device=device)
|
||||
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
||||
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
||||
return ret
|
||||
|
||||
|
||||
def add_optional_chunk_mask(xs: torch.Tensor,
|
||||
masks: torch.Tensor,
|
||||
use_dynamic_chunk: bool,
|
||||
use_dynamic_left_chunk: bool,
|
||||
decoding_chunk_size: int,
|
||||
static_chunk_size: int,
|
||||
num_decoding_left_chunks: int,
|
||||
enable_full_context: bool = True):
|
||||
""" Apply optional mask for encoder.
|
||||
|
||||
Args:
|
||||
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
||||
mask (torch.Tensor): mask for xs, (B, 1, L)
|
||||
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
||||
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
||||
training.
|
||||
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
||||
0: default for training, use random dynamic chunk.
|
||||
<0: for decoding, use full chunk.
|
||||
>0: for decoding, use fixed chunk size as set.
|
||||
static_chunk_size (int): chunk size for static chunk training/decoding
|
||||
if it's greater than 0, if use_dynamic_chunk is true,
|
||||
this parameter will be ignored
|
||||
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
||||
the chunk size is decoding_chunk_size.
|
||||
>=0: use num_decoding_left_chunks
|
||||
<0: use all left chunks
|
||||
enable_full_context (bool):
|
||||
True: chunk size is either [1, 25] or full context(max_len)
|
||||
False: chunk size ~ U[1, 25]
|
||||
|
||||
Returns:
|
||||
torch.Tensor: chunk mask of the input xs.
|
||||
"""
|
||||
# Whether to use chunk mask or not
|
||||
if use_dynamic_chunk:
|
||||
max_len = xs.size(1)
|
||||
if decoding_chunk_size < 0:
|
||||
chunk_size = max_len
|
||||
num_left_chunks = -1
|
||||
elif decoding_chunk_size > 0:
|
||||
chunk_size = decoding_chunk_size
|
||||
num_left_chunks = num_decoding_left_chunks
|
||||
else:
|
||||
# chunk size is either [1, 25] or full context(max_len).
|
||||
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
||||
# delay, the maximum frame is 100 / 4 = 25.
|
||||
chunk_size = torch.randint(1, max_len, (1, )).item()
|
||||
num_left_chunks = -1
|
||||
if chunk_size > max_len // 2 and enable_full_context:
|
||||
chunk_size = max_len
|
||||
else:
|
||||
chunk_size = chunk_size % 25 + 1
|
||||
if use_dynamic_left_chunk:
|
||||
max_left_chunks = (max_len - 1) // chunk_size
|
||||
num_left_chunks = torch.randint(0, max_left_chunks,
|
||||
(1, )).item()
|
||||
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
||||
num_left_chunks,
|
||||
xs.device) # (L, L)
|
||||
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
||||
chunk_masks = masks & chunk_masks # (B, L, L)
|
||||
elif static_chunk_size > 0:
|
||||
num_left_chunks = num_decoding_left_chunks
|
||||
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
||||
num_left_chunks,
|
||||
xs.device) # (L, L)
|
||||
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
||||
chunk_masks = masks & chunk_masks # (B, L, L)
|
||||
else:
|
||||
chunk_masks = masks
|
||||
assert chunk_masks.dtype == torch.bool
|
||||
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
|
||||
print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
|
||||
chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
|
||||
return chunk_masks
|
||||
|
||||
|
||||
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||
"""Make mask tensor containing indices of padded part.
|
||||
|
||||
See description of make_non_pad_mask.
|
||||
|
||||
Args:
|
||||
lengths (torch.Tensor): Batch of lengths (B,).
|
||||
Returns:
|
||||
torch.Tensor: Mask tensor containing indices of padded part.
|
||||
|
||||
Examples:
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_pad_mask(lengths)
|
||||
masks = [[0, 0, 0, 0 ,0],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 1, 1]]
|
||||
"""
|
||||
batch_size = lengths.size(0)
|
||||
max_len = max_len if max_len > 0 else lengths.max().item()
|
||||
seq_range = torch.arange(0,
|
||||
max_len,
|
||||
dtype=torch.int64,
|
||||
device=lengths.device)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
seq_length_expand = lengths.unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
return mask
|
||||
738
cosyvoice/utils/scheduler.py
Normal file
738
cosyvoice/utils/scheduler.py
Normal file
@@ -0,0 +1,738 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2022 Ximalaya Inc (Yuguang Yang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
# NeMo(https://github.com/NVIDIA/NeMo)
|
||||
|
||||
from typing import Union
|
||||
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
class WarmupLR(_LRScheduler):
|
||||
"""The WarmupLR scheduler
|
||||
|
||||
This scheduler is almost same as NoamLR Scheduler except for following
|
||||
difference:
|
||||
|
||||
NoamLR:
|
||||
lr = optimizer.lr * model_size ** -0.5
|
||||
* min(step ** -0.5, step * warmup_step ** -1.5)
|
||||
WarmupLR:
|
||||
lr = optimizer.lr * warmup_step ** 0.5
|
||||
* min(step ** -0.5, step * warmup_step ** -1.5)
|
||||
|
||||
Note that the maximum lr equals to optimizer.lr in this scheduler.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
warmup_steps: Union[int, float] = 25000,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
self.warmup_steps = warmup_steps
|
||||
|
||||
# __init__() must be invoked before setting field
|
||||
# because step() is also invoked in __init__()
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
|
||||
|
||||
def get_lr(self):
|
||||
step_num = self.last_epoch + 1
|
||||
if self.warmup_steps == 0:
|
||||
return [lr * step_num**-0.5 for lr in self.base_lrs]
|
||||
else:
|
||||
return [
|
||||
lr * self.warmup_steps**0.5 *
|
||||
min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
|
||||
for lr in self.base_lrs
|
||||
]
|
||||
|
||||
def set_step(self, step: int):
|
||||
self.last_epoch = step
|
||||
|
||||
|
||||
class WarmupPolicy(_LRScheduler):
|
||||
"""Adds warmup kwargs and warmup logic to lr policy.
|
||||
All arguments should be passed as kwargs for clarity,
|
||||
Args:
|
||||
warmup_steps: Number of training steps in warmup stage
|
||||
warmup_ratio: Ratio of warmup steps to total steps
|
||||
max_steps: Total number of steps while training or `None` for
|
||||
infinite training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
warmup_steps=None,
|
||||
warmup_ratio=None,
|
||||
max_steps=None,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1):
|
||||
assert not (warmup_steps is not None and warmup_ratio is not None),\
|
||||
"Either use particular number of step or ratio"
|
||||
assert warmup_ratio is None or max_steps is not None, \
|
||||
"If there is a ratio, there should be a total steps"
|
||||
|
||||
# It is necessary to assign all attributes *before* __init__,
|
||||
# as class is wrapped by an inner class.
|
||||
self.max_steps = max_steps
|
||||
if warmup_steps is not None:
|
||||
self.warmup_steps = warmup_steps
|
||||
elif warmup_ratio is not None:
|
||||
self.warmup_steps = int(warmup_ratio * max_steps)
|
||||
else:
|
||||
self.warmup_steps = 0
|
||||
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed "
|
||||
"by the scheduler, please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
stacklevel=2)
|
||||
|
||||
step = self.last_epoch
|
||||
|
||||
if step <= self.warmup_steps and self.warmup_steps > 0:
|
||||
return self._get_warmup_lr(step)
|
||||
|
||||
if step > self.max_steps:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
return self._get_lr(step)
|
||||
|
||||
def _get_warmup_lr(self, step):
|
||||
lr_val = (step + 1) / (self.warmup_steps + 1)
|
||||
return [initial_lr * lr_val for initial_lr in self.base_lrs]
|
||||
|
||||
def _get_lr(self, step):
|
||||
"""Simple const lr policy"""
|
||||
return self.base_lrs
|
||||
|
||||
|
||||
class SquareRootConstantPolicy(_LRScheduler):
|
||||
"""Adds warmup kwargs and warmup logic to lr policy.
|
||||
All arguments should be passed as kwargs for clarity,
|
||||
Args:
|
||||
warmup_steps: Number of training steps in warmup stage
|
||||
warmup_ratio: Ratio of warmup steps to total steps
|
||||
max_steps: Total number of steps while training or `None` for
|
||||
infinite training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
constant_steps=None,
|
||||
constant_ratio=None,
|
||||
max_steps=None,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1):
|
||||
assert not (constant_steps is not None
|
||||
and constant_ratio is not None), \
|
||||
"Either use particular number of step or ratio"
|
||||
assert constant_ratio is None or max_steps is not None, \
|
||||
"If there is a ratio, there should be a total steps"
|
||||
|
||||
# It is necessary to assign all attributes *before* __init__,
|
||||
# as class is wrapped by an inner class.
|
||||
self.max_steps = max_steps
|
||||
if constant_steps is not None:
|
||||
self.constant_steps = constant_steps
|
||||
elif constant_ratio is not None:
|
||||
self.constant_steps = int(constant_ratio * max_steps)
|
||||
else:
|
||||
self.constant_steps = 0
|
||||
|
||||
self.constant_lr = 1 / (constant_steps**0.5)
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed "
|
||||
"by the scheduler, please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
stacklevel=2)
|
||||
|
||||
step = self.last_epoch
|
||||
|
||||
if step <= self.constant_steps:
|
||||
return [self.constant_lr for _ in self.base_lrs]
|
||||
|
||||
if step > self.max_steps:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
return self._get_lr(step)
|
||||
|
||||
def _get_lr(self, step):
|
||||
"""Simple const lr policy"""
|
||||
return self.base_lrs
|
||||
|
||||
|
||||
class WarmupHoldPolicy(WarmupPolicy):
|
||||
"""Variant of WarmupPolicy which maintains high
|
||||
learning rate for a defined number of steps.
|
||||
All arguments should be passed as kwargs for clarity,
|
||||
Args:
|
||||
warmup_steps: Number of training steps in warmup stage
|
||||
warmup_ratio: Ratio of warmup steps to total steps
|
||||
hold_steps: Number of training steps to
|
||||
hold the learning rate after warm up
|
||||
hold_ratio: Ratio of hold steps to total steps
|
||||
max_steps: Total number of steps while training or `None` for
|
||||
infinite training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
*,
|
||||
warmup_steps=None,
|
||||
warmup_ratio=None,
|
||||
hold_steps=None,
|
||||
hold_ratio=None,
|
||||
max_steps=None,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1,
|
||||
):
|
||||
assert not (hold_steps is not None and hold_ratio is not None), \
|
||||
"Either use particular number of step or ratio"
|
||||
assert hold_ratio is None or max_steps is not None, \
|
||||
"If there is a ratio, there should be a total steps"
|
||||
|
||||
self.min_lr = min_lr
|
||||
self._last_warmup_lr = 0.0
|
||||
|
||||
# Necessary to duplicate as class attributes are hidden in inner class
|
||||
self.max_steps = max_steps
|
||||
if warmup_steps is not None:
|
||||
self.warmup_steps = warmup_steps
|
||||
elif warmup_ratio is not None:
|
||||
self.warmup_steps = int(warmup_ratio * max_steps)
|
||||
else:
|
||||
self.warmup_steps = 0
|
||||
|
||||
if hold_steps is not None:
|
||||
self.hold_steps = hold_steps + self.warmup_steps
|
||||
elif hold_ratio is not None:
|
||||
self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
|
||||
else:
|
||||
self.hold_steps = 0
|
||||
|
||||
super().__init__(
|
||||
optimizer,
|
||||
warmup_steps=warmup_steps,
|
||||
warmup_ratio=warmup_ratio,
|
||||
max_steps=max_steps,
|
||||
last_epoch=last_epoch,
|
||||
min_lr=min_lr,
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed by the scheduler,"
|
||||
" "
|
||||
"please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
stacklevel=2)
|
||||
|
||||
step = self.last_epoch
|
||||
|
||||
# Warmup phase
|
||||
if step <= self.warmup_steps and self.warmup_steps > 0:
|
||||
return self._get_warmup_lr(step)
|
||||
|
||||
# Hold phase
|
||||
if (step >= self.warmup_steps) and (step < self.hold_steps):
|
||||
return self.base_lrs
|
||||
|
||||
if step > self.max_steps:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
return self._get_lr(step)
|
||||
|
||||
|
||||
class WarmupAnnealHoldPolicy(_LRScheduler):
|
||||
"""Adds warmup kwargs and warmup logic to lr policy.
|
||||
All arguments should be passed as kwargs for clarity,
|
||||
Args:
|
||||
warmup_steps: Number of training steps in warmup stage
|
||||
warmup_ratio: Ratio of warmup steps to total steps
|
||||
max_steps: Total number of steps while training or `None` for
|
||||
infinite training
|
||||
min_lr: Minimum lr to hold the learning rate after decay at.
|
||||
constant_steps: Number of steps to keep lr constant at.
|
||||
constant_ratio: Ratio of steps to keep lr constant.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
*,
|
||||
warmup_steps=None,
|
||||
warmup_ratio=None,
|
||||
constant_steps=None,
|
||||
constant_ratio=None,
|
||||
max_steps=None,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1,
|
||||
):
|
||||
assert not (warmup_steps is not None
|
||||
and warmup_ratio is not None), \
|
||||
"Either use particular number of step or ratio"
|
||||
assert not (constant_steps is not None
|
||||
and constant_ratio is not None), \
|
||||
"Either use constant_steps or constant_ratio"
|
||||
assert warmup_ratio is None or max_steps is not None, \
|
||||
"If there is a ratio, there should be a total steps"
|
||||
|
||||
# It is necessary to assign all attributes *before* __init__,
|
||||
# as class is wrapped by an inner class.
|
||||
self.max_steps = max_steps
|
||||
|
||||
if warmup_steps is not None:
|
||||
self.warmup_steps = warmup_steps
|
||||
elif warmup_ratio is not None:
|
||||
self.warmup_steps = int(warmup_ratio * max_steps)
|
||||
else:
|
||||
self.warmup_steps = 0
|
||||
|
||||
if constant_steps is not None:
|
||||
self.constant_steps = constant_steps
|
||||
elif constant_ratio is not None:
|
||||
self.constant_steps = int(constant_ratio * max_steps)
|
||||
else:
|
||||
self.constant_steps = 0
|
||||
|
||||
self.decay_steps = max_steps - (self.constant_steps +
|
||||
self.warmup_steps)
|
||||
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed "
|
||||
"by the scheduler, please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
stacklevel=2)
|
||||
|
||||
step = self.last_epoch
|
||||
|
||||
# Warmup steps
|
||||
if self.warmup_steps > 0 and step <= self.warmup_steps:
|
||||
return self._get_warmup_lr(step)
|
||||
|
||||
# Constant steps after warmup and decay
|
||||
if self.constant_steps > 0 and (
|
||||
self.warmup_steps + self.decay_steps) < step <= self.max_steps:
|
||||
return self._get_constant_lr(step)
|
||||
|
||||
# Min lr after max steps of updates
|
||||
if step > self.max_steps:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
return self._get_lr(step)
|
||||
|
||||
def _get_warmup_lr(self, step):
|
||||
lr_val = (step + 1) / (self.warmup_steps + 1)
|
||||
return [initial_lr * lr_val for initial_lr in self.base_lrs]
|
||||
|
||||
def _get_constant_lr(self, step):
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
def _get_lr(self, step):
|
||||
"""Simple const lr policy"""
|
||||
return self.base_lrs
|
||||
|
||||
|
||||
def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
|
||||
mult = ((max_steps - step) / max_steps)**0.5
|
||||
out_lr = initial_lr * mult
|
||||
out_lr = max(out_lr, min_lr)
|
||||
return out_lr
|
||||
|
||||
|
||||
def _square_annealing(initial_lr, step, max_steps, min_lr):
|
||||
mult = ((max_steps - step) / max_steps)**2
|
||||
out_lr = initial_lr * mult
|
||||
out_lr = max(out_lr, min_lr)
|
||||
return out_lr
|
||||
|
||||
|
||||
def _cosine_annealing(initial_lr, step, max_steps, min_lr):
|
||||
mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
|
||||
out_lr = (initial_lr - min_lr) * mult + min_lr
|
||||
return out_lr
|
||||
|
||||
|
||||
def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
|
||||
decay_steps, min_lr):
|
||||
assert max_lr > min_lr
|
||||
# Use linear warmup for the initial part.
|
||||
if warmup_steps > 0 and step <= warmup_steps:
|
||||
return max_lr * float(step) / float(warmup_steps)
|
||||
|
||||
# For any steps larger than `decay_steps`, use `min_lr`.
|
||||
if step > warmup_steps + decay_steps:
|
||||
return min_lr
|
||||
|
||||
# If we are done with the warmup period, use the decay style.
|
||||
num_steps_ = step - warmup_steps
|
||||
decay_steps_ = decay_steps
|
||||
decay_ratio = float(num_steps_) / float(decay_steps_)
|
||||
assert decay_ratio >= 0.0
|
||||
assert decay_ratio <= 1.0
|
||||
delta_lr = max_lr - min_lr
|
||||
|
||||
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
|
||||
|
||||
return min_lr + coeff * delta_lr
|
||||
|
||||
|
||||
def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
|
||||
if cycle:
|
||||
multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
|
||||
decay_steps *= multiplier
|
||||
else:
|
||||
step = min(step, decay_steps)
|
||||
p = step / decay_steps
|
||||
lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
|
||||
lr += min_lr
|
||||
return lr
|
||||
|
||||
|
||||
def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
|
||||
decay_rate, min_lr):
|
||||
# hold_steps = total number of steps
|
||||
# to hold the LR, not the warmup + hold steps.
|
||||
T_warmup_decay = max(1, warmup_steps**decay_rate)
|
||||
T_hold_decay = max(1, (step - hold_steps)**decay_rate)
|
||||
lr = (initial_lr * T_warmup_decay) / T_hold_decay
|
||||
lr = max(lr, min_lr)
|
||||
return lr
|
||||
|
||||
|
||||
class SquareAnnealing(WarmupPolicy):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
max_steps,
|
||||
min_lr=1e-5,
|
||||
last_epoch=-1,
|
||||
**kwargs):
|
||||
super().__init__(optimizer=optimizer,
|
||||
max_steps=max_steps,
|
||||
last_epoch=last_epoch,
|
||||
min_lr=min_lr,
|
||||
**kwargs)
|
||||
|
||||
def _get_lr(self, step):
|
||||
new_lrs = [
|
||||
_square_annealing(
|
||||
initial_lr=initial_lr,
|
||||
step=step - self.warmup_steps,
|
||||
max_steps=self.max_steps - self.warmup_steps,
|
||||
min_lr=self.min_lr,
|
||||
) for initial_lr in self.base_lrs
|
||||
]
|
||||
return new_lrs
|
||||
|
||||
|
||||
class SquareRootAnnealing(WarmupPolicy):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
max_steps,
|
||||
min_lr=0,
|
||||
last_epoch=-1,
|
||||
**kwargs):
|
||||
super().__init__(optimizer=optimizer,
|
||||
max_steps=max_steps,
|
||||
last_epoch=last_epoch,
|
||||
min_lr=min_lr,
|
||||
**kwargs)
|
||||
|
||||
def _get_lr(self, step):
|
||||
new_lrs = [
|
||||
_squareroot_annealing(initial_lr=initial_lr,
|
||||
step=step,
|
||||
max_steps=self.max_steps,
|
||||
min_lr=self.min_lr)
|
||||
for initial_lr in self.base_lrs
|
||||
]
|
||||
return new_lrs
|
||||
|
||||
|
||||
class CosineAnnealing(WarmupAnnealHoldPolicy):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
max_steps,
|
||||
min_lr=0,
|
||||
last_epoch=-1,
|
||||
**kwargs):
|
||||
super().__init__(optimizer=optimizer,
|
||||
max_steps=max_steps,
|
||||
last_epoch=last_epoch,
|
||||
min_lr=min_lr,
|
||||
**kwargs)
|
||||
|
||||
def _get_lr(self, step):
|
||||
for initial_lr in self.base_lrs:
|
||||
if initial_lr < self.min_lr:
|
||||
raise ValueError(
|
||||
f"{self} received an initial learning rate "
|
||||
f"that was lower than the minimum learning rate.")
|
||||
|
||||
if self.constant_steps is None or self.constant_steps == 0:
|
||||
new_lrs = [
|
||||
_cosine_annealing(
|
||||
initial_lr=initial_lr,
|
||||
step=step - self.warmup_steps,
|
||||
max_steps=self.max_steps - self.warmup_steps,
|
||||
min_lr=self.min_lr,
|
||||
) for initial_lr in self.base_lrs
|
||||
]
|
||||
else:
|
||||
new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
|
||||
return new_lrs
|
||||
|
||||
def _get_warmup_lr(self, step):
|
||||
if self.constant_steps is None or self.constant_steps == 0:
|
||||
return super()._get_warmup_lr(step)
|
||||
else:
|
||||
# Use linear warmup for the initial part.
|
||||
return self._get_linear_warmup_with_cosine_annealing_lr(step)
|
||||
|
||||
def _get_constant_lr(self, step):
|
||||
# Only called when `constant_steps` > 0.
|
||||
return self._get_linear_warmup_with_cosine_annealing_lr(step)
|
||||
|
||||
def _get_linear_warmup_with_cosine_annealing_lr(self, step):
|
||||
# Cosine Schedule for Megatron LM,
|
||||
# slightly different warmup schedule + constant LR at the end.
|
||||
new_lrs = [
|
||||
_linear_warmup_with_cosine_annealing(
|
||||
max_lr=self.base_lrs[0],
|
||||
warmup_steps=self.warmup_steps,
|
||||
step=step,
|
||||
decay_steps=self.decay_steps,
|
||||
min_lr=self.min_lr,
|
||||
) for _ in self.base_lrs
|
||||
]
|
||||
return new_lrs
|
||||
|
||||
|
||||
class NoamAnnealing(_LRScheduler):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
d_model,
|
||||
warmup_steps=None,
|
||||
warmup_ratio=None,
|
||||
max_steps=None,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1):
|
||||
self._normalize = d_model**(-0.5)
|
||||
assert not (warmup_steps is not None and warmup_ratio is not None), \
|
||||
"Either use particular number of step or ratio"
|
||||
assert warmup_ratio is None or max_steps is not None, \
|
||||
"If there is a ratio, there should be a total steps"
|
||||
|
||||
# It is necessary to assign all attributes *before* __init__,
|
||||
# as class is wrapped by an inner class.
|
||||
self.max_steps = max_steps
|
||||
if warmup_steps is not None:
|
||||
self.warmup_steps = warmup_steps
|
||||
elif warmup_ratio is not None:
|
||||
self.warmup_steps = int(warmup_ratio * max_steps)
|
||||
else:
|
||||
self.warmup_steps = 0
|
||||
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed "
|
||||
"by the scheduler, please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
stacklevel=2)
|
||||
|
||||
step = max(1, self.last_epoch)
|
||||
|
||||
for initial_lr in self.base_lrs:
|
||||
if initial_lr < self.min_lr:
|
||||
raise ValueError(
|
||||
f"{self} received an initial learning rate "
|
||||
f"that was lower than the minimum learning rate.")
|
||||
|
||||
new_lrs = [
|
||||
self._noam_annealing(initial_lr=initial_lr, step=step)
|
||||
for initial_lr in self.base_lrs
|
||||
]
|
||||
return new_lrs
|
||||
|
||||
def _noam_annealing(self, initial_lr, step):
|
||||
if self.warmup_steps > 0:
|
||||
mult = self._normalize * min(step**(-0.5),
|
||||
step * (self.warmup_steps**(-1.5)))
|
||||
else:
|
||||
mult = self._normalize * step**(-0.5)
|
||||
|
||||
out_lr = initial_lr * mult
|
||||
if step > self.warmup_steps:
|
||||
out_lr = max(out_lr, self.min_lr)
|
||||
return out_lr
|
||||
|
||||
|
||||
class NoamHoldAnnealing(WarmupHoldPolicy):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
max_steps,
|
||||
decay_rate=0.5,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1,
|
||||
**kwargs):
|
||||
"""
|
||||
From Nemo:
|
||||
Implementation of the Noam Hold Annealing policy
|
||||
from the SqueezeFormer paper.
|
||||
|
||||
Unlike NoamAnnealing, the peak learning rate
|
||||
can be explicitly set for this scheduler.
|
||||
The schedule first performs linear warmup,
|
||||
then holds the peak LR, then decays with some schedule for
|
||||
the remainder of the steps.
|
||||
Therefore the min-lr is still dependent
|
||||
on the hyper parameters selected.
|
||||
|
||||
It's schedule is determined by three factors-
|
||||
|
||||
Warmup Steps: Initial stage, where linear warmup
|
||||
occurs uptil the peak LR is reached. Unlike NoamAnnealing,
|
||||
the peak LR is explicitly stated here instead of a scaling factor.
|
||||
|
||||
Hold Steps: Intermediate stage, where the peak LR
|
||||
is maintained for some number of steps. In this region,
|
||||
the high peak LR allows the model to converge faster
|
||||
if training is stable. However the high LR
|
||||
may also cause instability during training.
|
||||
Should usually be a significant fraction of training
|
||||
steps (around 30-40% of the entire training steps).
|
||||
|
||||
Decay Steps: Final stage, where the LR rapidly decays
|
||||
with some scaling rate (set by decay rate).
|
||||
To attain Noam decay, use 0.5,
|
||||
for Squeezeformer recommended decay, use 1.0.
|
||||
The fast decay after prolonged high LR during
|
||||
hold phase allows for rapid convergence.
|
||||
|
||||
References:
|
||||
- [Squeezeformer:
|
||||
An Efficient Transformer for Automatic Speech Recognition]
|
||||
(https://arxiv.org/abs/2206.00888)
|
||||
|
||||
Args:
|
||||
optimizer: Pytorch compatible Optimizer object.
|
||||
warmup_steps: Number of training steps in warmup stage
|
||||
warmup_ratio: Ratio of warmup steps to total steps
|
||||
hold_steps: Number of training steps to
|
||||
hold the learning rate after warm up
|
||||
hold_ratio: Ratio of hold steps to total steps
|
||||
max_steps: Total number of steps while training or `None` for
|
||||
infinite training
|
||||
decay_rate: Float value describing the polynomial decay
|
||||
after the hold period. Default value
|
||||
of 0.5 corresponds to Noam decay.
|
||||
min_lr: Minimum learning rate.
|
||||
"""
|
||||
self.decay_rate = decay_rate
|
||||
super().__init__(optimizer=optimizer,
|
||||
max_steps=max_steps,
|
||||
last_epoch=last_epoch,
|
||||
min_lr=min_lr,
|
||||
**kwargs)
|
||||
|
||||
def _get_lr(self, step):
|
||||
if self.warmup_steps is None or self.warmup_steps == 0:
|
||||
raise ValueError(
|
||||
"Noam scheduler cannot be used without warmup steps")
|
||||
|
||||
if self.hold_steps > 0:
|
||||
hold_steps = self.hold_steps - self.warmup_steps
|
||||
else:
|
||||
hold_steps = 0
|
||||
|
||||
new_lrs = [
|
||||
_noam_hold_annealing(
|
||||
initial_lr,
|
||||
step=step,
|
||||
warmup_steps=self.warmup_steps,
|
||||
hold_steps=hold_steps,
|
||||
decay_rate=self.decay_rate,
|
||||
min_lr=self.min_lr,
|
||||
) for initial_lr in self.base_lrs
|
||||
]
|
||||
return new_lrs
|
||||
|
||||
def set_step(self, step: int):
|
||||
self.last_epoch = step
|
||||
|
||||
|
||||
class ConstantLR(_LRScheduler):
|
||||
"""The ConstantLR scheduler
|
||||
|
||||
This scheduler keeps a constant lr
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
):
|
||||
# __init__() must be invoked before setting field
|
||||
# because step() is also invoked in __init__()
|
||||
super().__init__(optimizer)
|
||||
|
||||
def get_lr(self):
|
||||
return self.base_lrs
|
||||
|
||||
def set_step(self, step: int):
|
||||
self.last_epoch = step
|
||||
367
cosyvoice/utils/train_utils.py
Normal file
367
cosyvoice/utils/train_utils.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
||||
# 2023 Horizon Inc. (authors: Xingchen Song)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import re
|
||||
import datetime
|
||||
import yaml
|
||||
|
||||
import deepspeed
|
||||
import torch.optim as optim
|
||||
import torch.distributed as dist
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
|
||||
|
||||
from cosyvoice.dataset.dataset import Dataset
|
||||
from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
|
||||
|
||||
|
||||
def init_distributed(args):
|
||||
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
||||
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
||||
rank = int(os.environ.get('RANK', 0))
|
||||
logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
|
||||
', rank {}, world_size {}'.format(rank, world_size))
|
||||
if args.train_engine == 'torch_ddp':
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group(args.dist_backend)
|
||||
else:
|
||||
deepspeed.init_distributed(dist_backend=args.dist_backend)
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def init_dataset_and_dataloader(args, configs, gan, dpo):
|
||||
data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
|
||||
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
|
||||
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=False, partition=False)
|
||||
|
||||
# do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
|
||||
train_data_loader = DataLoader(train_dataset,
|
||||
batch_size=None,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch)
|
||||
cv_data_loader = DataLoader(cv_dataset,
|
||||
batch_size=None,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch)
|
||||
return train_dataset, cv_dataset, train_data_loader, cv_data_loader
|
||||
|
||||
|
||||
def check_modify_and_save_config(args, configs):
|
||||
if args.train_engine == "torch_ddp":
|
||||
configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
|
||||
else:
|
||||
with open(args.deepspeed_config, 'r') as fin:
|
||||
ds_configs = json.load(fin)
|
||||
if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
|
||||
configs['train_conf']["dtype"] = "fp16"
|
||||
elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
|
||||
configs['train_conf']["dtype"] = "bf16"
|
||||
else:
|
||||
configs['train_conf']["dtype"] = "fp32"
|
||||
assert ds_configs["train_micro_batch_size_per_gpu"] == 1
|
||||
# if use deepspeed, override ddp config
|
||||
configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
|
||||
configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
|
||||
configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
|
||||
configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
|
||||
configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
|
||||
return configs
|
||||
|
||||
|
||||
def wrap_cuda_model(args, model):
|
||||
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
|
||||
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
||||
if args.train_engine == "torch_ddp": # native pytorch ddp
|
||||
assert (torch.cuda.is_available())
|
||||
model.cuda()
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
|
||||
else:
|
||||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
logging.info("Estimating model states memory needs (zero2)...")
|
||||
estimate_zero2_model_states_mem_needs_all_live(
|
||||
model,
|
||||
num_gpus_per_node=local_world_size,
|
||||
num_nodes=world_size // local_world_size)
|
||||
return model
|
||||
|
||||
|
||||
def init_optimizer_and_scheduler(args, configs, model, gan):
|
||||
if gan is False:
|
||||
if configs['train_conf']['optim'] == 'adam':
|
||||
optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
|
||||
elif configs['train_conf']['optim'] == 'adamw':
|
||||
optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
|
||||
else:
|
||||
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['scheduler'] == 'warmuplr':
|
||||
scheduler_type = WarmupLR
|
||||
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
|
||||
scheduler_type = NoamHoldAnnealing
|
||||
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||
scheduler_type = ConstantLR
|
||||
scheduler = ConstantLR(optimizer)
|
||||
else:
|
||||
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||
|
||||
# use deepspeed optimizer for speedup
|
||||
if args.train_engine == "deepspeed":
|
||||
def scheduler(opt):
|
||||
return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
|
||||
model, optimizer, _, scheduler = deepspeed.initialize(
|
||||
args=args,
|
||||
model=model,
|
||||
optimizer=None,
|
||||
lr_scheduler=scheduler,
|
||||
model_parameters=model.parameters())
|
||||
|
||||
optimizer_d, scheduler_d = None, None
|
||||
|
||||
else:
|
||||
# currently we wrap generator and discriminator in one model, so we cannot use deepspeed
|
||||
if configs['train_conf']['optim'] == 'adam':
|
||||
optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
elif configs['train_conf']['optim'] == 'adamw':
|
||||
optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
else:
|
||||
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['scheduler'] == 'warmuplr':
|
||||
scheduler_type = WarmupLR
|
||||
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
|
||||
scheduler_type = NoamHoldAnnealing
|
||||
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||
scheduler_type = ConstantLR
|
||||
scheduler = ConstantLR(optimizer)
|
||||
else:
|
||||
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['optim_d'] == 'adam':
|
||||
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
elif configs['train_conf']['optim_d'] == 'adamw':
|
||||
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
else:
|
||||
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['scheduler_d'] == 'warmuplr':
|
||||
scheduler_type = WarmupLR
|
||||
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
|
||||
scheduler_type = NoamHoldAnnealing
|
||||
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||
scheduler_type = ConstantLR
|
||||
scheduler_d = ConstantLR(optimizer_d)
|
||||
else:
|
||||
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||
return model, optimizer, scheduler, optimizer_d, scheduler_d
|
||||
|
||||
|
||||
def init_summarywriter(args):
|
||||
writer = None
|
||||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
os.makedirs(args.model_dir, exist_ok=True)
|
||||
writer = SummaryWriter(args.tensorboard_dir)
|
||||
return writer
|
||||
|
||||
|
||||
def save_model(model, model_name, info_dict):
|
||||
rank = int(os.environ.get('RANK', 0))
|
||||
model_dir = info_dict["model_dir"]
|
||||
save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
|
||||
|
||||
if info_dict["train_engine"] == "torch_ddp":
|
||||
if rank == 0:
|
||||
torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(save_dir=model_dir,
|
||||
tag=model_name,
|
||||
client_state=info_dict)
|
||||
if rank == 0:
|
||||
info_path = re.sub('.pt$', '.yaml', save_model_path)
|
||||
info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
|
||||
with open(info_path, 'w') as fout:
|
||||
data = yaml.dump(info_dict)
|
||||
fout.write(data)
|
||||
logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
|
||||
|
||||
|
||||
def cosyvoice_join(group_join, info_dict):
|
||||
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
||||
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
||||
rank = int(os.environ.get('RANK', 0))
|
||||
|
||||
if info_dict["batch_idx"] != 0:
|
||||
# we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
|
||||
try:
|
||||
dist.monitored_barrier(group=group_join,
|
||||
timeout=group_join.options._timeout)
|
||||
return False
|
||||
except RuntimeError as e:
|
||||
logging.info("Detected uneven workload distribution: {}\n".format(e) +
|
||||
"Break current worker to manually join all workers, " +
|
||||
"world_size {}, current rank {}, current local_rank {}\n".
|
||||
format(world_size, rank, local_rank))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
|
||||
device = int(os.environ.get('LOCAL_RANK', 0))
|
||||
|
||||
dtype = info_dict["dtype"]
|
||||
if dtype == "fp16":
|
||||
dtype = torch.float16
|
||||
elif dtype == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
else: # fp32
|
||||
dtype = torch.float32
|
||||
|
||||
if info_dict['train_engine'] == 'torch_ddp':
|
||||
autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
|
||||
else:
|
||||
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
|
||||
|
||||
with autocast:
|
||||
info_dict['loss_dict'] = model(batch, device)
|
||||
if ref_model is not None and dpo_loss is not None:
|
||||
chosen_logps = info_dict['loss_dict']["chosen_logps"]
|
||||
rejected_logps = info_dict['loss_dict']["rejected_logps"]
|
||||
sft_loss = info_dict['loss_dict']['loss']
|
||||
with torch.no_grad():
|
||||
ref_loss_dict = ref_model(batch, device)
|
||||
reference_chosen_logps = ref_loss_dict["chosen_logps"]
|
||||
reference_rejected_logps = ref_loss_dict["rejected_logps"]
|
||||
preference_loss, chosen_reward, reject_reward = dpo_loss(
|
||||
chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
|
||||
)
|
||||
dpo_acc = (chosen_reward > reject_reward).float().mean()
|
||||
info_dict['loss_dict']["loss"] = preference_loss + sft_loss
|
||||
info_dict['loss_dict']["sft_loss"] = sft_loss
|
||||
info_dict['loss_dict']["dpo_loss"] = preference_loss
|
||||
info_dict['loss_dict']["dpo_acc"] = dpo_acc
|
||||
info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean()
|
||||
info_dict['loss_dict']["reject_reward"] = reject_reward.mean()
|
||||
return info_dict
|
||||
|
||||
|
||||
def batch_backward(model, scaler, info_dict):
|
||||
if info_dict["train_engine"] == "deepspeed":
|
||||
scaled_loss = model.backward(info_dict['loss_dict']['loss'])
|
||||
else:
|
||||
scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
|
||||
if scaler is not None:
|
||||
scaler.scale(scaled_loss).backward()
|
||||
else:
|
||||
scaled_loss.backward()
|
||||
|
||||
info_dict['loss_dict']['loss'] = scaled_loss
|
||||
return info_dict
|
||||
|
||||
|
||||
def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
|
||||
grad_norm = 0.0
|
||||
if info_dict['train_engine'] == "deepspeed":
|
||||
info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
|
||||
model.step()
|
||||
grad_norm = model.get_global_grad_norm()
|
||||
elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
|
||||
# Use mixed precision training
|
||||
if scaler is not None:
|
||||
scaler.unscale_(optimizer)
|
||||
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
|
||||
# We don't check grad here since that if the gradient
|
||||
# has inf/nan values, scaler.step will skip
|
||||
# optimizer.step().
|
||||
if torch.isfinite(grad_norm):
|
||||
scaler.step(optimizer)
|
||||
else:
|
||||
logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
|
||||
scaler.update()
|
||||
else:
|
||||
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
|
||||
if torch.isfinite(grad_norm):
|
||||
optimizer.step()
|
||||
else:
|
||||
logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
|
||||
optimizer.zero_grad()
|
||||
scheduler.step()
|
||||
info_dict["lr"] = optimizer.param_groups[0]['lr']
|
||||
info_dict["grad_norm"] = grad_norm
|
||||
return info_dict
|
||||
|
||||
|
||||
def log_per_step(writer, info_dict):
|
||||
tag = info_dict["tag"]
|
||||
epoch = info_dict.get('epoch', 0)
|
||||
step = info_dict["step"]
|
||||
batch_idx = info_dict["batch_idx"]
|
||||
loss_dict = info_dict['loss_dict']
|
||||
rank = int(os.environ.get('RANK', 0))
|
||||
|
||||
# only rank 0 write to tensorboard to avoid multi-process write
|
||||
if writer is not None:
|
||||
if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
|
||||
(info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
|
||||
for k in ['epoch', 'lr', 'grad_norm']:
|
||||
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
|
||||
for k, v in loss_dict.items():
|
||||
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
|
||||
|
||||
# TRAIN & CV, Shell log (stdout)
|
||||
if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
|
||||
log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
|
||||
for name, value in loss_dict.items():
|
||||
log_str += '{} {:.6f} '.format(name, value)
|
||||
if tag == "TRAIN":
|
||||
log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
|
||||
info_dict["lr"], info_dict['grad_norm'])
|
||||
log_str += ' rank {}'.format(rank)
|
||||
logging.debug(log_str)
|
||||
|
||||
|
||||
def log_per_save(writer, info_dict):
|
||||
tag = info_dict["tag"]
|
||||
epoch = info_dict["epoch"]
|
||||
step = info_dict["step"]
|
||||
loss_dict = info_dict["loss_dict"]
|
||||
lr = info_dict['lr']
|
||||
rank = int(os.environ.get('RANK', 0))
|
||||
logging.info(
|
||||
'Epoch {} Step {} CV info lr {} {} rank {}'.format(
|
||||
epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()])))
|
||||
|
||||
if writer is not None:
|
||||
for k in ['epoch', 'lr']:
|
||||
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
|
||||
for k, v in loss_dict.items():
|
||||
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
|
||||
Reference in New Issue
Block a user