
undesirable behaviour of find_lengths function #138

Open
mbaroni opened this issue Nov 5, 2020 · 3 comments

Labels: enhancement (New feature or request)

mbaroni commented Nov 5, 2020

def find_lengths(messages: torch.Tensor) -> torch.Tensor:
    """
    :param messages: A tensor of term ids, encoded as Long values, of size (batch size, max sequence length).
    :returns A tensor with lengths of the sequences, including the end-of-sequence symbol <eos> (in EGG, it is 0).
             If no <eos> is found, the full length is returned (i.e. messages.size(1)).
    """

This leads to counterintuitive behaviour in which, if max_len is 3, [1, 2, 3] and [1, 2, 0] have the same length.
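
For concreteness, a minimal sketch of the documented semantics (an illustration, not the EGG implementation) shows the two messages coming out with the same length:

import torch

def find_lengths_sketch(messages: torch.Tensor) -> torch.Tensor:
    # Sketch of the documented behaviour: the length includes the first
    # <eos> symbol (0); if no <eos> occurs, the full messages.size(1)
    # is returned.
    max_len = messages.size(1)
    lengths = torch.full((messages.size(0),), max_len, dtype=torch.long)
    for i, row in enumerate(messages):
        eos_positions = (row == 0).nonzero(as_tuple=True)[0]
        if eos_positions.numel() > 0:
            lengths[i] = eos_positions[0] + 1  # count the <eos> symbol too
    return lengths

msgs = torch.tensor([[1, 2, 3],   # never emits <eos>
                     [1, 2, 0]])  # emits <eos> in the last position
print(find_lengths_sketch(msgs))  # tensor([3, 3]) -- both get length 3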


robertodessi commented Nov 5, 2020

Quickly thinking about it, I see two possible options:

  1. not allowing messages without EOS (i.e. raising an error/throwing an exception, or returning a special value like 0 or -1)
  2. considering [1, 2, 3] as length 3 and [1, 2, 0] as length 2

I have a preference for 2. but am open to discussing it.
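
A hypothetical variant along the lines of option 2, counting only the symbols before the first <eos> (the name find_lengths_option2 and its details are illustrative, not part of EGG), could look like this:

import torch

def find_lengths_option2(messages: torch.Tensor) -> torch.Tensor:
    # Count only the symbols strictly before the first <eos> (0), so a
    # message that spends its whole budget without emitting <eos> is
    # distinguishable from one that terminates in the last position.
    eos_mask = messages == 0
    after_eos = eos_mask.long().cumsum(dim=1) > 0  # True at and after the first <eos>
    return (~after_eos).sum(dim=1)

msgs = torch.tensor([[1, 2, 3],
                     [1, 2, 0]])
print(find_lengths_option2(msgs))  # tensor([3, 2])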

@mbaroni mbaroni added the enhancement New feature or request label Nov 13, 2020
@robertodessi robertodessi changed the title undesirable behavious of find_lengths function undesirable behaviour of find_lengths function Feb 1, 2021

tomkouwenhoven commented Jun 23, 2023

Hi,

Is there any update on this issue? I am also working on variable-length communication, and appending an EOS token to each message (see the three code lines below) makes find_lengths return opts.max_len + 1 when the sender itself produces no EOS token. This is counterintuitive: one sets max_len to a specific value, yet find_lengths returns lengths longer than that value.

sequence = torch.stack(sequence).permute(1, 0)
zeros = torch.zeros((sequence.size(0), 1)).to(sequence.device)
sequence = torch.cat([sequence, zeros.long()], dim=1)
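
As an illustration (assuming opts.max_len is 3 and that find_lengths can be imported from egg.core; adjust the import to your EGG version), a message that never emits <eos> is reported with length max_len + 1 once the extra zero column is appended:

import torch
from egg.core import find_lengths  # adjust to your EGG version if needed

max_len = 3
sequence = torch.tensor([[1, 2, 3]])                           # sender output, no <eos>
zeros = torch.zeros((sequence.size(0), 1), dtype=torch.long)
sequence = torch.cat([sequence, zeros], dim=1)                 # [[1, 2, 3, 0]]
print(find_lengths(sequence))  # tensor([4]), i.e. max_len + 1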

In terms of the options you mentioned on November 5th, 2020, I would also opt for option 2.

Is the current best solution still to increase the max_len parameter by one, as mentioned in #188?

Thanks in advance,
Tom Kouwenhoven

@robertodessi

Hi,

No concrete plans to work on this in the near future. Do you want to give it a go? :)
