Hi,
would you mind explaining some of the hard-coded numbers in the `template_entity` function from `inference.py`?
```python
def template_entity(words, input_TXT, start):
    # input text -> template
    words_length = len(words)
    words_length_list = [len(i) for i in words]
    input_TXT = [input_TXT]*(5*words_length)

    input_ids = tokenizer(input_TXT, return_tensors='pt')['input_ids']
    model.to(device)
    template_list = [" is a location entity .", " is a person entity .", " is an organization entity .",
                     " is an other entity .", " is not a named entity ."]
    entity_dict = {0: 'LOC', 1: 'PER', 2: 'ORG', 3: 'MISC', 4: 'O'}
    temp_list = []
    for i in range(words_length):
        for j in range(len(template_list)):
            temp_list.append(words[i]+template_list[j])

    output_ids = tokenizer(temp_list, return_tensors='pt', padding=True, truncation=True)['input_ids']
    output_ids[:, 0] = 2
    output_length_list = [0]*5*words_length

    for i in range(len(temp_list)//5):
        base_length = ((tokenizer(temp_list[i * 5], return_tensors='pt', padding=True, truncation=True)['input_ids']).shape)[1] - 4
        output_length_list[i*5:i*5+ 5] = [base_length]*5
        output_length_list[i*5+4] += 1

    score = [1]*5*words_length
    with torch.no_grad():
        output = model(input_ids=input_ids.to(device), decoder_input_ids=output_ids[:, :output_ids.shape[1] - 2].to(device))[0]
        for i in range(output_ids.shape[1] - 3):
            # print(input_ids.shape)
            logits = output[:, i, :]
            logits = logits.softmax(dim=1)
            # values, predictions = logits.topk(1,dim = 1)
            logits = logits.to('cpu').numpy()
            # print(output_ids[:, i+1].item())
            for j in range(0, 5*words_length):
                if i < output_length_list[j]:
                    score[j] = score[j] * logits[j][int(output_ids[j][i + 1])]

    end = start+(score.index(max(score))//5)
    # score_list.append(score)
    return [start, end, entity_dict[(score.index(max(score))%5)], max(score)]  # [start_index, end_index, label, score]
```
I learned from the existing issues that the `5`s are the length of `template_list`, but what about the other numbers?
It would be a great help if you could respond to this. Thank you in advance!
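For reference, the part that is already understood (the `5`s being the length of `template_list`) can be sketched without the model. This is only an illustration under that assumption: `NUM_TEMPLATES`, `build_candidates`, and `decode_best` are hypothetical names that do not appear in `inference.py`, and the sketch says nothing about the remaining constants.

```python
# template_list and entity_dict are copied from the post; every other name
# here is hypothetical and introduced only for illustration.
template_list = [" is a location entity .", " is a person entity .",
                 " is an organization entity .", " is an other entity .",
                 " is not a named entity ."]
entity_dict = {0: 'LOC', 1: 'PER', 2: 'ORG', 3: 'MISC', 4: 'O'}

# Assumption from the post: each literal 5 stands for len(template_list).
NUM_TEMPLATES = len(template_list)


def build_candidates(words):
    """Return one candidate sentence per (span, template) pair, span-major."""
    return [w + t for w in words for t in template_list]


def decode_best(scores, start):
    """Map the flat index of the best score back to (end index, label)."""
    best = scores.index(max(scores))
    # Integer division recovers the span offset, the remainder the label id,
    # mirroring the `// 5` and `% 5` in the original return statement.
    span_offset, label_id = divmod(best, NUM_TEMPLATES)
    return [start, start + span_offset, entity_dict[label_id], scores[best]]


candidates = build_candidates(["New", "New York"])
scores = [0.1] * len(candidates)
scores[6] = 0.9  # pretend "New York is a person entity ." scored highest
result = decode_best(scores, start=3)
```

With two candidate spans there are 10 candidates, and flat index 6 decodes to span offset 1 and label id 1, the same mapping the original code performs on `score.index(max(score))`.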
Have you solved this problem?