-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathpredict.py
More file actions
23 lines (15 loc) · 870 Bytes
/
predict.py
File metadata and controls
23 lines (15 loc) · 870 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from tlxzoo.module.t5 import T5Transform
import tensorlayerx as tlx
from tlxzoo.text.text_conditional_generation import TextForConditionalGeneration
if __name__ == '__main__':
    # Demo script: run a single sentence through a T5 conditional-generation
    # model and print the decoded output.
    # NOTE(review): every path below is relative ("./demo/..."), so the script
    # presumably has to be launched from the repository root -- confirm.

    # Build the model from the "t5" configuration name and restore the weights.
    model = TextForConditionalGeneration("t5")
    model.load_weights("./demo/text/nmt/t5/model.npz")
    # Presumably switches the model to inference mode -- confirm against the
    # model's set_eval() implementation.
    model.set_eval()

    # Transform backed by the SentencePiece vocabulary file; both length
    # limits are set to 128.
    transform = T5Transform(
        vocab_file="./demo/text/nmt/t5/spiece.model",
        source_max_length=128,
        label_max_length=128,
    )

    text = "Plane giants often trade blows on technical matters through advertising in the trade press."

    # The transform takes two arguments and returns a pair. The second argument
    # is passed as an empty string (presumably the label text, unused when
    # predicting -- confirm), and the second returned item is not needed here.
    x, _ = transform(text, "")

    # Each feature is wrapped in a one-element list before the tensor
    # conversion, which adds a leading dimension of size 1.
    inputs = tlx.convert_to_tensor([x["inputs"]], dtype=tlx.int64)
    attention_mask = tlx.convert_to_tensor([x["attention_mask"]], dtype=tlx.int64)

    # Generate output ids, then decode the first entry of the result to text.
    decode_id = model.generate_one(inputs=inputs, attention_mask=attention_mask)
    decode_str = transform.ids_to_string(decode_id[0])
    print(decode_str)