The current Convention is not adapted for Python 3.10 (an issue caused by [T] generic annotations)
cosyvoice/utils/losses.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import torch
import torch.nn.functional as F
from typing import Tuple


def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
    loss = 0
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        # Median gap between real and generated discriminator scores.
        m_DG = torch.median(dr - dg)
        # Mean squared deviation from the median gap, over elements where the
        # generated score comes within the median gap of the real score.
        L_rel = torch.mean(((dr - dg - m_DG) ** 2)[dr < dg + m_DG])
        # Each term is clipped at tau.
        loss += tau - F.relu(tau - L_rel)
    return loss


def mel_loss(real_speech, generated_speech, mel_transforms):
    # Mel-spectrogram L1 loss, summed over a set of mel transforms
    # (typically multiple resolutions).
    loss = 0
    for transform in mel_transforms:
        mel_r = transform(real_speech)
        mel_g = transform(generated_speech)
        loss += F.l1_loss(mel_g, mel_r)
    return loss


class DPOLoss(torch.nn.Module):
    """
    Direct Preference Optimization (DPO) loss.

    Compares policy and reference log-probabilities of chosen vs. rejected
    completions. Supports label smoothing (cDPO) and the IPO variant.
    """

    def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Log-ratio of chosen over rejected, under the policy and the reference model.
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios
        if self.ipo:
            losses = (logits - 1 / (2 * self.beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
        else:
            # Eq. 3 of https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 recovers
            # the original DPO loss (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf).
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        loss = losses.mean()
        # Implicit rewards (detached; useful for logging only).
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

        return loss, chosen_rewards, rejected_rewards
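
A minimal usage sketch for the new losses, assuming torchaudio's MelSpectrogram as the mel transforms and per-sequence summed log-probabilities as inputs to DPOLoss; every name, shape, and hyperparameter below is illustrative rather than taken from this commit.

import torch
import torchaudio

from cosyvoice.utils.losses import DPOLoss, mel_loss

# Hypothetical multi-resolution mel transforms; any callables mapping waveform -> mel work.
mel_transforms = [
    torchaudio.transforms.MelSpectrogram(sample_rate=22050, n_fft=n_fft, hop_length=n_fft // 4, n_mels=80)
    for n_fft in (512, 1024, 2048)
]
real = torch.randn(2, 22050)   # (batch, samples), illustrative shapes
fake = torch.randn(2, 22050)
l_mel = mel_loss(real, fake, mel_transforms)

# DPOLoss takes summed log-probs of chosen/rejected sequences under the policy
# and a frozen reference model; the numbers below are made up.
dpo = DPOLoss(beta=0.1)
loss, chosen_rewards, rejected_rewards = dpo(
    torch.tensor([-12.3, -9.8]),    # policy_chosen_logps
    torch.tensor([-14.1, -11.2]),   # policy_rejected_logps
    torch.tensor([-12.0, -10.1]),   # reference_chosen_logps
    torch.tensor([-13.5, -11.0]),   # reference_rejected_logps
)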