Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 77 additions & 24 deletions src/chatterbox/models/s3gen/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
import logging
import random
from typing import Dict, Optional
from types import SimpleNamespace
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from .utils.mask import make_pad_mask


Expand All @@ -33,13 +33,38 @@ def __init__(self,
encoder: torch.nn.Module = None,
length_regulator: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
decoder_conf: Dict = {
'in_channels': 240,
'out_channel': 80,
'spk_emb_dim': 80,
'n_spks': 1,
'cfm_params': SimpleNamespace(
sigma_min=1e-6,
solver='euler',
t_scheduler='cosine',
training_cfg_rate=0.2,
inference_cfg_rate=0.7,
reg_loss_type='l1',
),
'decoder_params': {
'channels': [256, 256],
'dropout': 0.0,
'attention_head_dim': 64,
'n_blocks': 4,
'num_mid_blocks': 12,
'num_heads': 8,
'act_fn': 'gelu',
},
},
mel_feat_conf: Dict = {
'n_fft': 1024,
'num_mels': 80,
'sampling_rate': 22050,
'hop_size': 256,
'win_size': 1024,
'fmin': 0,
'fmax': 8000,
}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
Expand All @@ -56,6 +81,7 @@ def __init__(self,
self.decoder = decoder
self.length_regulator = length_regulator
self.only_mask_loss = only_mask_loss
self.fp16 = False # (preserves existing behaviour)

def forward(
self,
Expand Down Expand Up @@ -111,7 +137,7 @@ def inference(self,
prompt_feat_len,
embedding,
flow_cache):
if self.fp16 is True:
if self.fp16:
prompt_feat = prompt_feat.half()
embedding = embedding.half()

Expand All @@ -129,8 +155,11 @@ def inference(self,
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
mel_len1 = prompt_feat.shape[1]
mel_len2 = int(token_len2 / self.input_frame_rate * 22050 / 256)
h, h_lengths = self.length_regulator.inference(
h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate
)

# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
Expand Down Expand Up @@ -165,13 +194,38 @@ def __init__(self,
pre_lookahead_len: int = 3,
encoder: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
decoder_conf: Dict = {
'in_channels': 240,
'out_channel': 80,
'spk_emb_dim': 80,
'n_spks': 1,
'cfm_params': SimpleNamespace(
sigma_min=1e-6,
solver='euler',
t_scheduler='cosine',
training_cfg_rate=0.2,
inference_cfg_rate=0.7,
reg_loss_type='l1',
),
'decoder_params': {
'channels': [256, 256],
'dropout': 0.0,
'attention_head_dim': 64,
'n_blocks': 4,
'num_mid_blocks': 12,
'num_heads': 8,
'act_fn': 'gelu',
},
},
mel_feat_conf: Dict = {
'n_fft': 1024,
'num_mels': 80,
'sampling_rate': 22050,
'hop_size': 256,
'win_size': 1024,
'fmin': 0,
'fmax': 8000,
}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
Expand All @@ -189,9 +243,7 @@ def __init__(self,
self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio
self.pre_lookahead_len = pre_lookahead_len

# FIXME: this was missing - just putting it in as false
self.fp16 = False
self.fp16 = False # ensures attribute exists

@torch.inference_mode()
def inference(self,
Expand All @@ -203,7 +255,7 @@ def inference(self,
prompt_feat_len,
embedding,
finalize):
if self.fp16 is True:
if self.fp16:
prompt_feat = prompt_feat.half()
embedding = embedding.half()

Expand All @@ -219,9 +271,10 @@ def inference(self,

# text encode
h, h_lengths = self.encoder(token, token_len)
if finalize is False:
if not finalize:
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
mel_len1 = prompt_feat.shape[1]
mel_len2 = h.shape[1] - mel_len1
h = self.encoder_proj(h)

# get conditions
Expand Down
18 changes: 9 additions & 9 deletions src/chatterbox/models/s3gen/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
import torch
import torch.nn.functional as F
from .matcha.flow_matching import BASECFM
from omegaconf import OmegaConf
from types import SimpleNamespace


CFM_PARAMS = OmegaConf.create({
"sigma_min": 1e-06,
"solver": "euler",
"t_scheduler": "cosine",
"training_cfg_rate": 0.2,
"inference_cfg_rate": 0.7,
"reg_loss_type": "l1"
})
CFM_PARAMS = SimpleNamespace(
sigma_min=1e-6,
solver="euler",
t_scheduler="cosine",
training_cfg_rate=0.2,
inference_cfg_rate=0.7,
reg_loss_type="l1",
)


class ConditionalCFM(BASECFM):
Expand Down
18 changes: 9 additions & 9 deletions src/chatterbox/models/s3gen/s3gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torchaudio as ta
from functools import lru_cache
from typing import Optional
from omegaconf import DictConfig
from types import SimpleNamespace

from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
from .const import S3GEN_SR
Expand Down Expand Up @@ -85,14 +85,14 @@ def __init__(self):
num_heads=8,
act_fn='gelu',
)
cfm_params = DictConfig({
"sigma_min": 1e-06,
"solver": 'euler',
"t_scheduler": 'cosine',
"training_cfg_rate": 0.2,
"inference_cfg_rate": 0.7,
"reg_loss_type": 'l1',
})
cfm_params = SimpleNamespace(
sigma_min=1e-6,
solver='euler',
t_scheduler='cosine',
training_cfg_rate=0.2,
inference_cfg_rate=0.7,
reg_loss_type='l1',
)
decoder = CausalConditionalCFM(
spk_emb_dim=80,
cfm_params=cfm_params,
Expand Down