跳转至

估计模型的参数量和浮点运算次数（FLOPs）

import sys
import os
sys.path.append(os.getcwd())
from f5_tts.model import CFM, DiT
import torch
import thop

# Build a DiT backbone from the f5_tts package, wrap it in CFM, and print the
# operation count and parameter count that thop.profile reports for a single
# call of the model on dummy inputs.

# Settings that determine the dummy input sizes.  n_mel_channels is defined
# before the backbone so that the model's mel_dim and the last dimension of
# the dummy mel tensor come from one value and cannot drift apart.
target_sample_rate = 24000
n_mel_channels = 128
hop_length = 256
duration = 20  # presumably seconds of audio -- TODO confirm
# Number of frames in the dummy mel tensor: duration * sample rate / hop size
# (1875 for the values above).
frame_length = int(duration * target_sample_rate / hop_length)
text_length = 150  # length of the dummy integer sequence passed as text

transformer = DiT(
    dim=1024,
    depth=22,
    heads=16,
    ff_mult=2,
    text_dim=512,
    conv_layers=4,
    mel_dim=n_mel_channels,
)

model = CFM(transformer=transformer)

# Dummy inputs: a random (1, frame_length, n_mel_channels) float tensor and an
# all-zero (1, text_length) int64 tensor.
dummy_mel = torch.randn(1, frame_length, n_mel_channels)
dummy_text = torch.zeros(1, text_length, dtype=torch.long)

# NOTE(review): thop's documentation names the first return value of
# profile() "macs" (multiply-accumulate operations); this script labels it
# FLOPs -- confirm which unit is intended before comparing with other numbers.
flops, params = thop.profile(model, inputs=(dummy_mel, dummy_text))
print(f"FLOPs: {flops / 1e9} G")
print(f"Params: {params / 1e6} M")