AudioLDM 2
Last updated
AudioLDM 2 was proposed in by Haohe Liu et al. AudioLDM 2 takes a text prompt as input and predicts the corresponding audio. It can generate text-conditional sound effects, human speech and music.
Inspired by , AudioLDM 2 is a text-to-audio latent diffusion model (LDM) that learns continuous audio representations from text embeddings. Two text encoder models compute the text embeddings from a prompt input: the text branch of CLAP and the encoder of Flan-T5. These text embeddings are then projected to a shared embedding space by an AudioLDM2ProjectionModel. A GPT-2 language model (LM) auto-regressively predicts eight new embedding vectors, conditioned on the projected CLAP and Flan-T5 embeddings. The generated embedding vectors and the Flan-T5 text embeddings are used as cross-attention conditioning in the LDM. The UNet of AudioLDM 2 is unique in that it takes two cross-attention embeddings, as opposed to the single cross-attention conditioning used in most other LDMs.
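The conditioning path in the paragraph above can be traced through tensor shapes. The NumPy sketch below is a toy illustration only. Random matrices stand in for the trained projection layers, and `toy_language_model` is a hypothetical stand-in for GPT-2. All sizes are made-up values, not the checkpoint configuration. Treating the CLAP text output as one pooled vector per prompt is also an assumption of this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes chosen for readability (not the real checkpoint config).
batch, t5_len = 2, 16
clap_dim, t5_dim, lm_dim = 12, 24, 32
num_generated = 8  # the language model predicts eight new embedding vectors

# Stand-ins for the two frozen text encoders' outputs.
clap_embeds = rng.normal(size=(batch, 1, clap_dim))    # CLAP text branch (assumed pooled)
t5_embeds = rng.normal(size=(batch, t5_len, t5_dim))   # Flan-T5 encoder hidden states

# Step 1: project both into a shared width and concatenate along the sequence axis.
shared = np.concatenate(
    [clap_embeds @ rng.normal(size=(clap_dim, lm_dim)),
     t5_embeds @ rng.normal(size=(t5_dim, lm_dim))],
    axis=1,
)

# Step 2: hypothetical stand-in for the language model, returning num_generated
# vectors derived from the shared sequence (here just a pooled, squashed summary).
def toy_language_model(seq, n):
    summary = np.tanh(seq.mean(axis=1, keepdims=True))
    return np.repeat(summary, n, axis=1)

generated = toy_language_model(shared, num_generated)

# Step 3: the UNet receives two separate cross-attention conditionings.
conditioning = {
    "encoder_hidden_states": generated,     # (batch, 8, lm_dim)
    "encoder_hidden_states_1": t5_embeds,   # (batch, t5_len, t5_dim)
}
for name, tensor in conditioning.items():
    print(name, tensor.shape)
```

The two conditionings keep their own sequence lengths and widths. This is why the UNet needs two cross-attention inputs rather than one concatenated sequence.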
The abstract of the paper is the following:
Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called language of audio (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate new state-of-the-art or competitive performance to previous approaches.
This pipeline was contributed by . The original codebase can be found at .
AudioLDM2 comes in three variants. Two of these checkpoints are applicable to the general task of text-to-audio generation. The third checkpoint is trained exclusively on text-to-music generation.
All checkpoints share the same model size for the text encoders and VAE. They differ in the size and depth of the UNet. See the table below for details on the three checkpoints:
Task            UNet model size   Total model size   Training data (hours)
Text-to-audio   350M              1.1B               1150k
Text-to-audio   750M              1.5B               1150k
Text-to-music   350M              1.1B               665k
Descriptive prompt inputs work best. Use adjectives to describe the sound (e.g. "high quality" or "clear"), and make the prompt context-specific (e.g. "water stream in a forest" instead of "stream").
It's best to use general terms like "cat" or "dog" instead of specific names or abstract objects the model may not be familiar with.
Using a negative prompt can significantly improve the quality of the generated waveform by guiding the generation away from terms that correspond to poor-quality audio. Try using a negative prompt of "Low quality."
The quality of the predicted audio sample can be controlled by the num_inference_steps argument. More steps give higher-quality audio at the expense of slower inference.
The length of the predicted audio sample can be controlled by varying the audio_length_in_s argument.
The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation.
Multiple waveforms can be generated in one go: set num_waveforms_per_prompt to a value greater than 1. The generated waveforms are then automatically scored against the prompt text and ranked from best to worst.
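The automatic scoring in the last tip can be pictured as a cosine-similarity ranking in the joint CLAP text-audio embedding space. The NumPy sketch below assumes the text and audio embeddings have already been computed. The function name `rank_waveforms` and the toy vectors are hypothetical, and the pipeline's own scoring code may differ in detail.

```python
import numpy as np

def rank_waveforms(text_embed, audio_embeds):
    """Hypothetical helper: order waveforms from best to worst match with the prompt.

    text_embed: (dim,) embedding of the prompt.
    audio_embeds: (num_waveforms, dim) embeddings of the generated waveforms.
    """
    text = text_embed / np.linalg.norm(text_embed)
    audio = audio_embeds / np.linalg.norm(audio_embeds, axis=1, keepdims=True)
    scores = audio @ text          # cosine similarity of each waveform with the prompt
    return np.argsort(-scores), scores

text_embed = np.array([1.0, 0.0, 0.0])
audio_embeds = np.array([
    [0.0, 1.0, 0.0],   # orthogonal to the prompt
    [0.9, 0.1, 0.0],   # closely aligned
    [0.5, 0.5, 0.0],   # partially aligned
])
order, scores = rank_waveforms(text_embed, audio_embeds)
print(order)  # [1 2 0]
```

Normalizing both sides first makes the score depend only on direction in the embedding space, not on embedding magnitude.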
( vae: AutoencoderKL, text_encoder: ClapModel, text_encoder_2: T5EncoderModel, projection_model: AudioLDM2ProjectionModel, language_model: GPT2Model, tokenizer: typing.Union[transformers.models.roberta.tokenization_roberta.RobertaTokenizer, transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast], tokenizer_2: typing.Union[transformers.models.t5.tokenization_t5.T5Tokenizer, transformers.models.t5.tokenization_t5_fast.T5TokenizerFast], feature_extractor: ClapFeatureExtractor, unet: AudioLDM2UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, vocoder: SpeechT5HifiGan )
Parameters
language_model (GPT2Model): An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected outputs from the two text encoders.
tokenizer (RobertaTokenizer): Tokenizer to tokenize text for the first frozen text-encoder.
tokenizer_2 (T5Tokenizer): Tokenizer to tokenize text for the second frozen text-encoder.
feature_extractor (ClapFeatureExtractor): Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
vocoder (SpeechT5HifiGan): Vocoder of class SpeechT5HifiGan to convert the mel spectrograms decoded by the VAE to the final audio waveform.
Pipeline for text-to-audio generation using AudioLDM2.
__call__
Parameters
prompt (str or List[str], optional): The prompt or prompts to guide audio generation. If not defined, you need to pass prompt_embeds.
audio_length_in_s (float, optional, defaults to 10.24): The length of the generated audio sample in seconds.
num_inference_steps (int, optional, defaults to 200): The number of denoising steps. More denoising steps usually lead to higher-quality audio at the expense of slower inference.
guidance_scale (float, optional, defaults to 3.5): A higher guidance scale value encourages the model to generate audio that is closely linked to the text prompt at the expense of lower sound quality. Guidance scale is enabled when guidance_scale > 1.
negative_prompt (str or List[str], optional): The prompt or prompts to guide what to not include in audio generation. If not defined, you need to pass negative_prompt_embeds instead. Ignored when not using guidance (guidance_scale <= 1).
num_waveforms_per_prompt (int, optional, defaults to 1): The number of waveforms to generate per prompt. If num_waveforms_per_prompt > 1, automatic scoring is performed between the generated outputs and the text prompt. This scoring ranks the generated waveforms by their cosine similarity with the text input in the joint text-audio embedding space.
latents (torch.FloatTensor, optional): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random generator.
prompt_embeds (torch.FloatTensor, optional): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the prompt input argument.
negative_prompt_embeds (torch.FloatTensor, optional): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, negative_prompt_embeds are generated from the negative_prompt input argument.
generated_prompt_embeds (torch.FloatTensor, optional): Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated from the prompt input argument.
negative_generated_prompt_embeds (torch.FloatTensor, optional): Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, negative_generated_prompt_embeds will be computed from the negative_prompt input argument.
attention_mask (torch.LongTensor, optional): Pre-computed attention mask to be applied to the prompt_embeds. If not provided, the attention mask will be computed from the prompt input argument.
negative_attention_mask (torch.LongTensor, optional): Pre-computed attention mask to be applied to the negative_prompt_embeds. If not provided, the attention mask will be computed from the negative_prompt input argument.
max_new_tokens (int, optional, defaults to None): Number of new tokens to generate with the GPT2 language model. If not provided, the number of tokens will be taken from the config of the model.
callback (Callable, optional): A function that is called every callback_steps steps during inference. The function is called with the following arguments: callback(step: int, timestep: int, latents: torch.FloatTensor).
callback_steps (int, optional, defaults to 1): The frequency at which the callback function is called. If not specified, the callback is called at every step.
output_type (str, optional, defaults to "np"): The output format of the generated audio. Choose between "np" to return a NumPy np.ndarray or "pt" to return a PyTorch torch.Tensor object. Set to "latent" to return the latent diffusion model (LDM) output.
Returns
The call function to the pipeline for generation.
Examples:
disable_vae_slicing
( )
Disable sliced VAE decoding. If enable_vae_slicing was previously enabled, this method will go back to computing decoding in one step.
enable_vae_slicing
( )
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
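Sliced decoding trades speed for lower peak memory. It decodes the batch a slice at a time and concatenates the results. The NumPy sketch below illustrates the idea in that spirit. `toy_decode` is a hypothetical stand-in for the VAE decoder, not the diffusers implementation. Each sample is decoded independently, so the sliced and unsliced outputs match.

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.normal(size=(4, 8))

def toy_decode(latents):
    # Hypothetical stand-in for a VAE decoder: a per-sample linear map plus nonlinearity.
    return np.tanh(latents @ weights)

def sliced_decode(latents, slice_size=1):
    # Decode `slice_size` samples at a time to bound peak memory, then re-assemble.
    chunks = [toy_decode(latents[i:i + slice_size])
              for i in range(0, len(latents), slice_size)]
    return np.concatenate(chunks, axis=0)

latents = rng.normal(size=(6, 4))
full = toy_decode(latents)
sliced = sliced_decode(latents, slice_size=2)
print(np.allclose(full, sliced))  # True
```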
encode_prompt
( prompt, device, num_waveforms_per_prompt, do_classifier_free_guidance, negative_prompt = None, prompt_embeds: typing.Optional[torch.FloatTensor] = None, negative_prompt_embeds: typing.Optional[torch.FloatTensor] = None, generated_prompt_embeds: typing.Optional[torch.FloatTensor] = None, negative_generated_prompt_embeds: typing.Optional[torch.FloatTensor] = None, attention_mask: typing.Optional[torch.LongTensor] = None, negative_attention_mask: typing.Optional[torch.LongTensor] = None, max_new_tokens: typing.Optional[int] = None ) → prompt_embeds (torch.FloatTensor)
Parameters
prompt (str or List[str], optional): The prompt to be encoded.
device (torch.device): The torch device.
num_waveforms_per_prompt (int): The number of waveforms that should be generated per prompt.
do_classifier_free_guidance (bool): Whether to use classifier-free guidance.
negative_prompt (str or List[str], optional): The prompt or prompts not to guide the audio generation. If not defined, one has to pass negative_prompt_embeds instead. Ignored when not using guidance (i.e., ignored if guidance_scale is not greater than 1).
prompt_embeds (torch.FloatTensor, optional): Pre-computed text embeddings from the Flan-T5 model. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be computed from the prompt input argument.
negative_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative text embeddings from the Flan-T5 model. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, negative_prompt_embeds will be computed from the negative_prompt input argument.
generated_prompt_embeds (torch.FloatTensor, optional): Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated from the prompt input argument.
negative_generated_prompt_embeds (torch.FloatTensor, optional): Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, negative_generated_prompt_embeds will be computed from the negative_prompt input argument.
attention_mask (torch.LongTensor, optional): Pre-computed attention mask to be applied to the prompt_embeds. If not provided, the attention mask will be computed from the prompt input argument.
negative_attention_mask (torch.LongTensor, optional): Pre-computed attention mask to be applied to the negative_prompt_embeds. If not provided, the attention mask will be computed from the negative_prompt input argument.
max_new_tokens (int, optional, defaults to None): The number of new tokens to generate with the GPT2 language model.
Returns
prompt_embeds (torch.FloatTensor): Text embeddings from the Flan-T5 model.
attention_mask (torch.LongTensor): Attention mask to be applied to the prompt_embeds.
generated_prompt_embeds (torch.FloatTensor): Text embeddings generated from the GPT2 language model.
Encodes the prompt into text encoder hidden states.
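With classifier-free guidance, diffusers pipelines usually batch the two branches together. Each embedding is repeated num_waveforms_per_prompt times, and the negative (unconditional) embeddings are stacked in front of the positive ones, so that one UNet call covers both branches. The NumPy sketch below shows that batching pattern. The negatives-first ordering is an assumption here, and `batch_for_guidance` is a hypothetical helper, not part of the pipeline API.

```python
import numpy as np

def batch_for_guidance(prompt_embeds, negative_prompt_embeds, num_waveforms_per_prompt):
    """Hypothetical helper mirroring common classifier-free-guidance batching.

    prompt_embeds, negative_prompt_embeds: (batch, seq_len, dim)
    Returns (2 * batch * num_waveforms_per_prompt, seq_len, dim), negatives first.
    """
    # np.repeat duplicates each prompt in place: [p0, p0, p0, p1, p1, p1]
    pos = np.repeat(prompt_embeds, num_waveforms_per_prompt, axis=0)
    neg = np.repeat(negative_prompt_embeds, num_waveforms_per_prompt, axis=0)
    return np.concatenate([neg, pos], axis=0)

prompt = np.ones((2, 3, 4))
negative = np.zeros((2, 3, 4))
batched = batch_for_guidance(prompt, negative, num_waveforms_per_prompt=3)
print(batched.shape)  # (12, 3, 4)
```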
Example:
generate_language_model
( inputs_embeds: Tensor = None, max_new_tokens: int = 8, **model_kwargs ) → inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size))
Parameters
inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)): The sequence used as a prompt for the generation.
max_new_tokens (int): Number of new tokens to generate.
model_kwargs (Dict[str, Any], optional): Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the forward function of the model.
Returns
inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)): The sequence of generated hidden-states.
Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
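The generation loop can be sketched as follows. The model runs on the current sequence of embeddings, its last hidden state is appended as the next input embedding, and after max_new_tokens steps only the newly generated states are returned. `toy_model` is a hypothetical stand-in for GPT2Model, with random weights, no key/value caching, and a cumulative mean in place of causal attention. The sketch therefore shows the control flow only, not the pipeline's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size = 16
w = rng.normal(size=(hidden_size, hidden_size)) / np.sqrt(hidden_size)

def toy_model(inputs_embeds):
    # Hypothetical stand-in for GPT2Model: one hidden state per input position.
    # A cumulative mean mimics causal masking (position t only sees positions <= t).
    counts = np.arange(1, inputs_embeds.shape[1] + 1)[None, :, None]
    causal_mean = np.cumsum(inputs_embeds, axis=1) / counts
    return np.tanh(causal_mean @ w)

def generate_hidden_states(inputs_embeds, max_new_tokens=8):
    for _ in range(max_new_tokens):
        hidden = toy_model(inputs_embeds)
        # Feed the last hidden state back in as the next input embedding.
        inputs_embeds = np.concatenate([inputs_embeds, hidden[:, -1:, :]], axis=1)
    # Keep only the newly generated positions.
    return inputs_embeds[:, -max_new_tokens:, :]

prompt = rng.normal(size=(2, 5, hidden_size))   # (batch_size, sequence_length, hidden_size)
generated = generate_hidden_states(prompt, max_new_tokens=8)
print(generated.shape)  # (2, 8, 16)
```

Hidden states are fed back directly, not sampled tokens. That is why the output is a sequence of continuous vectors rather than token IDs.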
( text_encoder_dim, text_encoder_1_dim, langauge_model_dim )
Parameters
text_encoder_dim (int): Dimensionality of the text embeddings from the first text encoder (CLAP).
text_encoder_1_dim (int): Dimensionality of the text embeddings from the second text encoder (T5 or VITS).
langauge_model_dim (int): Dimensionality of the text embeddings from the language model (GPT2).
A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned embedding vectors at the start and end of each text embedding sequence. Variables suffixed with _1 refer to the second text encoder; all others refer to the first.
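A minimal NumPy sketch of the projection described above. Each encoder's hidden states pass through their own linear layer. Learned start (SOS) and end (EOS) vectors are placed around each projected sequence, and the two sequences are concatenated. Random arrays stand in for the learned parameters. Extending the attention masks with ones for the inserted tokens is an assumption of this sketch. The parameter name langauge_model_dim is spelled as in the signature above.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy sizes; langauge_model_dim is spelled as in the documented signature.
text_encoder_dim, text_encoder_1_dim, langauge_model_dim = 12, 24, 32

# Random stand-ins for the learned parameters.
projection = rng.normal(size=(text_encoder_dim, langauge_model_dim))
projection_1 = rng.normal(size=(text_encoder_1_dim, langauge_model_dim))
sos, eos = rng.normal(size=langauge_model_dim), rng.normal(size=langauge_model_dim)
sos_1, eos_1 = rng.normal(size=langauge_model_dim), rng.normal(size=langauge_model_dim)

def add_special_tokens(hidden_states, mask, sos_token, eos_token):
    batch = hidden_states.shape[0]
    start = np.broadcast_to(sos_token, (batch, 1, sos_token.shape[0]))
    end = np.broadcast_to(eos_token, (batch, 1, eos_token.shape[0]))
    hidden_states = np.concatenate([start, hidden_states, end], axis=1)
    # The inserted tokens are always attended to.
    ones = np.ones((batch, 1), dtype=mask.dtype)
    return hidden_states, np.concatenate([ones, mask, ones], axis=1)

def project(hidden_states, hidden_states_1, attention_mask, attention_mask_1):
    h, m = add_special_tokens(hidden_states @ projection, attention_mask, sos, eos)
    h1, m1 = add_special_tokens(hidden_states_1 @ projection_1, attention_mask_1, sos_1, eos_1)
    return np.concatenate([h, h1], axis=1), np.concatenate([m, m1], axis=1)

clap = rng.normal(size=(2, 1, text_encoder_dim))
t5 = rng.normal(size=(2, 10, text_encoder_1_dim))
out, mask = project(clap, t5, np.ones((2, 1), dtype=np.int64), np.ones((2, 10), dtype=np.int64))
print(out.shape, mask.shape)  # (2, 15, 32) (2, 15)
```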
forward
( hidden_states: typing.Optional[torch.FloatTensor] = None, hidden_states_1: typing.Optional[torch.FloatTensor] = None, attention_mask: typing.Optional[torch.LongTensor] = None, attention_mask_1: typing.Optional[torch.LongTensor] = None )
( sample_size: typing.Optional[int] = None, in_channels: int = 4, out_channels: int = 4, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: typing.Tuple[str] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'), mid_block_type: typing.Optional[str] = 'UNetMidBlock2DCrossAttn', up_block_types: typing.Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), only_cross_attention: typing.Union[bool, typing.Tuple[bool]] = False, block_out_channels: typing.Tuple[int] = (320, 640, 1280, 1280), layers_per_block: typing.Union[int, typing.Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = 'silu', norm_num_groups: typing.Optional[int] = 32, norm_eps: float = 1e-05, cross_attention_dim: typing.Union[int, typing.Tuple[int]] = 1280, transformer_layers_per_block: typing.Union[int, typing.Tuple[int]] = 1, attention_head_dim: typing.Union[int, typing.Tuple[int]] = 8, num_attention_heads: typing.Union[int, typing.Tuple[int], NoneType] = None, use_linear_projection: bool = False, class_embed_type: typing.Optional[str] = None, num_class_embeds: typing.Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = 'default', time_embedding_type: str = 'positional', time_embedding_dim: typing.Optional[int] = None, time_embedding_act_fn: typing.Optional[str] = None, timestep_post_act: typing.Optional[str] = None, time_cond_proj_dim: typing.Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: typing.Optional[int] = None, class_embeddings_concat: bool = False )
Parameters
sample_size (int or Tuple[int, int], optional, defaults to None): Height and width of input/output sample.
in_channels (int, optional, defaults to 4): Number of channels in the input sample.
out_channels (int, optional, defaults to 4): Number of channels in the output.
flip_sin_to_cos (bool, optional, defaults to True): Whether to flip the sin to cos in the time embedding.
freq_shift (int, optional, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (Tuple[str], optional, defaults to ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")): The tuple of downsample blocks to use.
mid_block_type (str, optional, defaults to "UNetMidBlock2DCrossAttn"): Block type for the middle of the UNet. For AudioLDM2 it can only be UNetMidBlock2DCrossAttn.
up_block_types (Tuple[str], optional, defaults to ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")): The tuple of upsample blocks to use.
only_cross_attention (bool or Tuple[bool], optional, defaults to False): Whether to include self-attention in the basic transformer blocks; see BasicTransformerBlock.
block_out_channels (Tuple[int], optional, defaults to (320, 640, 1280, 1280)): The tuple of output channels for each block.
layers_per_block (int, optional, defaults to 2): The number of layers per block.
downsample_padding (int, optional, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (float, optional, defaults to 1.0): The scale factor to use for the mid block.
act_fn (str, optional, defaults to "silu"): The activation function to use.
norm_num_groups (int, optional, defaults to 32): The number of groups to use for the normalization. If None, normalization and activation layers are skipped in post-processing.
norm_eps (float, optional, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (int or Tuple[int], optional, defaults to 1280): The dimension of the cross-attention features.
transformer_layers_per_block (int or Tuple[int], optional, defaults to 1): The number of transformer blocks of type BasicTransformerBlock. Only relevant for CrossAttnDownBlock2D, CrossAttnUpBlock2D and UNetMidBlock2DCrossAttn.
attention_head_dim (int, optional, defaults to 8): The dimension of the attention heads.
num_attention_heads (int, optional): The number of attention heads. If not defined, defaults to attention_head_dim.
resnet_time_scale_shift (str, optional, defaults to "default"): Time scale shift config for ResNet blocks (see ResnetBlock2D). Choose from default or scale_shift.
class_embed_type (str, optional, defaults to None): The type of class embedding to use, which is ultimately summed with the time embeddings. Choose from None, "timestep", "identity", "projection", or "simple_projection".
num_class_embeds (int, optional, defaults to None): Input dimension of the learnable embedding matrix to be projected to time_embed_dim, when performing class conditioning with class_embed_type equal to None.
time_embedding_type (str, optional, defaults to positional): The type of position embedding to use for timesteps. Choose from positional or fourier.
time_embedding_dim (int, optional, defaults to None): An optional override for the dimension of the projected time embedding.
time_embedding_act_fn (str, optional, defaults to None): Optional activation function to apply once to the time embeddings before they are passed to the rest of the UNet. Choose from silu, mish, gelu, and swish.
timestep_post_act (str, optional, defaults to None): The second activation function to use in the timestep embedding. Choose from silu, mish and gelu.
time_cond_proj_dim (int, optional, defaults to None): The dimension of the cond_proj layer in the timestep embedding.
conv_in_kernel (int, optional, defaults to 3): The kernel size of the conv_in layer.
conv_out_kernel (int, optional, defaults to 3): The kernel size of the conv_out layer.
projection_class_embeddings_input_dim (int, optional): The dimension of the class_labels input when class_embed_type="projection". Required when class_embed_type="projection".
class_embeddings_concat (bool, optional, defaults to False): Whether to concatenate the time embeddings with the class embeddings.
forward
Parameters
sample (torch.FloatTensor): The noisy input tensor with the following shape (batch, channel, height, width).
timestep (torch.FloatTensor or float or int): The current timestep of the denoising process.
encoder_hidden_states (torch.FloatTensor): The encoder hidden states with shape (batch, sequence_length, feature_dim).
encoder_attention_mask (torch.Tensor): A cross-attention mask of shape (batch, sequence_length) applied to encoder_hidden_states. Positions where the mask is True are kept; positions where it is False are discarded. The mask is converted into a bias, which adds large negative values to the attention scores of the "discard" tokens.
cross_attention_kwargs (dict, optional): A kwargs dictionary that, if specified, is passed along to the AttnProcessor.
encoder_hidden_states_1 (torch.FloatTensor, optional): A second set of encoder hidden states with shape (batch, sequence_length_2, feature_dim_2). Can be used to condition the model on a different set of embeddings from encoder_hidden_states.
encoder_attention_mask_1 (torch.Tensor, optional): A cross-attention mask of shape (batch, sequence_length_2) applied to encoder_hidden_states_1. Positions where the mask is True are kept; positions where it is False are discarded. The mask is converted into a bias, which adds large negative values to the attention scores of the "discard" tokens.
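The mask-to-bias conversion described for encoder_attention_mask and encoder_attention_mask_1 can be sketched in NumPy. The value -10000.0 is an assumed constant for illustration. Any sufficiently large negative number drives the discarded tokens' weights to (near) zero after the softmax. The helpers `mask_to_bias` and `softmax` are hypothetical.

```python
import numpy as np

def mask_to_bias(mask, large_negative=-10000.0):
    """Convert a keep/discard mask (True = keep) into an additive attention bias."""
    mask = np.asarray(mask, dtype=np.float32)
    # Kept tokens get 0, discarded tokens get a large negative value.
    bias = (1.0 - mask) * large_negative
    # Add a query axis so the bias broadcasts over every query position.
    return bias[:, None, :]

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

scores = np.zeros((1, 2, 4))                   # (batch, queries, keys), uniform scores
mask = np.array([[True, True, True, False]])   # the last key token is padding
weights = softmax(scores + mask_to_bias(mask))
# The padded key receives (near-)zero weight; the three kept keys share the rest.
print(weights.round(3))
```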
Returns
( audios: ndarray )
Parameters
audios (np.ndarray): List of denoised audio samples of a NumPy array of shape (batch_size, num_channels, sample_rate).
Output class for audio pipelines.
The following example demonstrates how to generate good music using the aforementioned tips: .
Make sure to check out the Schedulers to learn how to explore the tradeoff between scheduler speed and quality, and see the section to learn how to efficiently load the same components into multiple pipelines.
vae (AutoencoderKL): Variational Auto-Encoder (VAE) model to encode and decode images (here, mel spectrograms) to and from latent representations.
text_encoder (ClapModel): First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model CLAP, specifically the variant. The text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to rank generated waveforms against the text prompt by computing similarity scores.
text_encoder_2 (T5EncoderModel): Second frozen text-encoder. AudioLDM2 uses the encoder of Flan-T5, specifically the variant.
projection_model (AudioLDM2ProjectionModel): A trained model used to linearly project the hidden-states from the first and second text encoder models and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are concatenated to give the input to the language model.
unet (AudioLDM2UNet2DConditionModel): A conditional UNet to denoise the encoded audio latents.
scheduler: A scheduler to be used in combination with unet to denoise the encoded audio latents. Can be one of , , or .
This model inherits from . Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).
( prompt: typing.Union[str, typing.List[str]] = None, audio_length_in_s: typing.Optional[float] = None, num_inference_steps: int = 200, guidance_scale: float = 3.5, negative_prompt: typing.Union[str, typing.List[str], NoneType] = None, num_waveforms_per_prompt: typing.Optional[int] = 1, eta: float = 0.0, generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None, latents: typing.Optional[torch.FloatTensor] = None, prompt_embeds: typing.Optional[torch.FloatTensor] = None, negative_prompt_embeds: typing.Optional[torch.FloatTensor] = None, generated_prompt_embeds: typing.Optional[torch.FloatTensor] = None, negative_generated_prompt_embeds: typing.Optional[torch.FloatTensor] = None, attention_mask: typing.Optional[torch.LongTensor] = None, negative_attention_mask: typing.Optional[torch.LongTensor] = None, max_new_tokens: typing.Optional[int] = None, return_dict: bool = True, callback: typing.Union[typing.Callable[[int, int, torch.FloatTensor], NoneType], NoneType] = None, callback_steps: typing.Optional[int] = 1, cross_attention_kwargs: typing.Union[typing.Dict[str, typing.Any], NoneType] = None, output_type: typing.Optional[str] = 'np' ) → or tuple
eta (float, optional, defaults to 0.0): Corresponds to parameter eta (η) from the paper. Only applies to the , and is ignored in other schedulers.
generator (torch.Generator or List[torch.Generator], optional): A torch.Generator to make generation deterministic.
return_dict (bool, optional, defaults to True): Whether or not to return a instead of a plain tuple.
cross_attention_kwargs (dict, optional): A kwargs dictionary that, if specified, is passed along to the AttentionProcessor as defined in .
or tuple: If return_dict is True, is returned; otherwise a tuple is returned where the first element is a list with the generated audio.
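guidance_scale enters through the standard classifier-free guidance update. The unconditional and text-conditioned noise predictions are combined as ε = ε_uncond + s·(ε_text − ε_uncond). The sketch below applies that combination to toy arrays. It assumes the batch is stacked with the unconditional half first, which is a common diffusers convention but an assumption here. `apply_guidance` is a hypothetical helper. With s = 1 the result equals the conditional prediction, which is why guidance only has an effect when guidance_scale > 1.

```python
import numpy as np

def apply_guidance(noise_pred, guidance_scale):
    """Combine a stacked [unconditional, conditional] noise prediction.

    noise_pred: (2 * batch, ...) with the unconditional half first (assumed ordering).
    """
    noise_uncond, noise_text = np.split(noise_pred, 2, axis=0)
    # Move away from the unconditional prediction, toward the text-conditioned one.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

uncond = np.zeros((1, 4))
cond = np.ones((1, 4))
stacked = np.concatenate([uncond, cond], axis=0)
print(apply_guidance(stacked, 1.0))  # [[1. 1. 1. 1.]]
print(apply_guidance(stacked, 3.5))  # [[3.5 3.5 3.5 3.5]]
```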
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample-shaped output. Compared to the vanilla UNet2DConditionModel, this variant optionally includes an additional self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up to two cross-attention embeddings, encoder_hidden_states and encoder_hidden_states_1.
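One way to picture a Transformer block with two cross-attention conditionings is shown below. The hidden states first attend to themselves, then to encoder_hidden_states, then to encoder_hidden_states_1, with a residual connection around each step. The single-head attention, the random untrained projection matrices, and the toy sizes are all assumptions for illustration. The real blocks also use normalization, multiple heads and feed-forward layers, and their internal ordering may differ from this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

def attention(x, context, dim):
    """Single-head scaled dot-product attention with random (untrained) projections."""
    wq = rng.normal(size=(x.shape[-1], dim)) / np.sqrt(x.shape[-1])
    wk = rng.normal(size=(context.shape[-1], dim)) / np.sqrt(context.shape[-1])
    # The value projection maps back to x's width so the residual add works.
    wv = rng.normal(size=(context.shape[-1], x.shape[-1])) / np.sqrt(context.shape[-1])
    q, k, v = x @ wq, context @ wk, context @ wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(dim)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def dual_cross_attention_block(x, encoder_hidden_states, encoder_hidden_states_1, dim=16):
    x = x + attention(x, x, dim)                         # self-attention
    x = x + attention(x, encoder_hidden_states, dim)     # first conditioning
    x = x + attention(x, encoder_hidden_states_1, dim)   # second conditioning
    return x

latent_tokens = rng.normal(size=(2, 64, 32))   # flattened latent positions, width 32
generated = rng.normal(size=(2, 8, 48))        # e.g. language-model output
t5 = rng.normal(size=(2, 20, 24))              # e.g. Flan-T5 embeddings
out = dual_cross_attention_block(latent_tokens, generated, t5)
print(out.shape)  # (2, 64, 32)
```

The two conditionings can differ in both sequence length and feature width, because each cross-attention step has its own key and value projections.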
This model inherits from . Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving).
( sample: FloatTensor, timestep: typing.Union[torch.Tensor, float, int], encoder_hidden_states: Tensor, class_labels: typing.Optional[torch.Tensor] = None, timestep_cond: typing.Optional[torch.Tensor] = None, attention_mask: typing.Optional[torch.Tensor] = None, cross_attention_kwargs: typing.Union[typing.Dict[str, typing.Any], NoneType] = None, encoder_attention_mask: typing.Optional[torch.Tensor] = None, return_dict: bool = True, encoder_hidden_states_1: typing.Optional[torch.Tensor] = None, encoder_attention_mask_1: typing.Optional[torch.Tensor] = None ) → or tuple
return_dict (bool, optional, defaults to True): Whether or not to return a instead of a plain tuple.
or tuple: If return_dict is True, an is returned; otherwise a tuple is returned where the first element is the sample tensor.
The forward method.