File tree 1 file changed +15
-4
lines changed
1 file changed +15
-4
lines changed Original file line number Diff line number Diff line change @@ -35,10 +35,21 @@ def get_timestep_embedding(
35
35
"""
36
36
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
37
37
38
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
39
- These may be fractional.
40
- :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
41
- embeddings. :return: an [N x dim] Tensor of positional embeddings.
38
+ Args
39
+ timesteps (torch.Tensor):
40
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
41
+ embedding_dim (int):
42
+ the dimension of the output.
43
+ flip_sin_to_cos (bool):
44
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
45
+ downscale_freq_shift (float):
46
+ Controls the delta between frequencies between dimensions
47
+ scale (float):
48
+ Scaling factor applied to the embeddings.
49
+ max_period (int):
50
+ Controls the maximum frequency of the embeddings
51
+ Returns
52
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
42
53
"""
43
54
assert len (timesteps .shape ) == 1 , "Timesteps should be a 1d-array"
44
55
You can’t perform that action at this time.
0 commit comments