-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathinteractive.py
141 lines (120 loc) · 4.59 KB
/
interactive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Basic script which allows local human keyboard input to talk to a trained model.
Examples
--------
.. code-block:: shell
python projects/convai2/interactive.py -mf models:convai2/kvmemnn/model
When prompted, chat with the bot; you will both be assigned personalities!
Use "[DONE]" to indicate you are done with that chat partner, and want a new one.
"""
from parlai.core.params import ParlaiParser
from parlai.core.agents import create_agent
from parlai.core.worlds import create_task
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
from parlai.agents.local_human.local_human import LocalHumanAgent
import random
import os
# Location of the pretrained checkpoint. It is used as the default
# ``model_file`` in setup_args() and as ``init_model_transmitter`` in the
# ``__main__`` block of this script.
# NOTE(review): 'psqaure' looks like a transposition of the directory name
# 'psquare' -- confirm against the file on disk before renaming anything.
pretrained_model_file = './tmp/psquare/psqaure_original.model'
def setup_args(parser=None):
    """Return a parser carrying the options of the interactive chat script.

    :param parser: an existing parser to extend. When ``None``, a new
        ``ParlaiParser`` described as 'Interactive chat with a model' is
        created.
    :return: the parser, with the display options registered,
        ``model_file`` defaulting to ``pretrained_model_file`` and the
        ``LocalHumanAgent`` command-line arguments added.
    """
    parser = (
        ParlaiParser(True, True, 'Interactive chat with a model')
        if parser is None
        else parser
    )

    # (flags, keyword settings) for each display-related option, registered
    # in this order.
    display_options = (
        (
            ('-d', '--display-examples'),
            {'type': 'bool', 'default': False},
        ),
        (
            ('--display-prettify',),
            {
                'type': 'bool',
                'default': False,
                'help': 'Set to use a prettytable when displaying '
                'examples with text candidates',
            },
        ),
        (
            ('--display-ignore-fields',),
            {
                'type': str,
                'default': 'label_candidates,text_candidates',
                'help': 'Do not display these fields',
            },
        ),
    )
    for flags, settings in display_options:
        parser.add_argument(*flags, **settings)

    # Point the model at the pretrained checkpoint unless overridden.
    parser.set_defaults(model_file=pretrained_model_file)
    LocalHumanAgent.add_cmdline_args(parser)
    return parser
def interactive(opt, print_parser=None):
    """Chat from the keyboard with a trained model, using ConvAI2 personas.

    Every conversation starts by drawing a fresh pair of personas from the
    ``convai2:both`` task. The partner's persona lines are printed for the
    human (relabelled as "your persona"), and the bot's persona lines are
    prepended to the text of the first message of the chat so the model
    observes them. When the first agent of the world reports its episode as
    done, a new pair of personas is drawn and a new chat begins. The loop
    has no exit condition of its own.

    :param opt: parsed options. It is indexed, copied and queried like a
        dict, and ``opt['task']`` is overwritten in place. Passing a
        ``ParlaiParser`` is deprecated but still accepted: a warning is
        printed and ``parse_args()`` is called on it.
    :param print_parser: controls printing of the arguments once the model
        has been created. A parser object is used as given; ``True`` means
        "use ``opt``", which only has an effect when ``opt`` is a
        ``ParlaiParser``; ``False`` or ``None`` disables printing.
    """
    # Normalise print_parser to either a parser object or None.
    if print_parser is True:
        # True can only be honoured when opt is itself a parser. Otherwise
        # there is nothing to print with, so disable printing instead of
        # later failing on ``True.opt``.
        print_parser = opt if isinstance(opt, ParlaiParser) else None
    elif print_parser is False:
        print_parser = None

    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: interactive should be passed opt not Parser ]')
        opt = opt.parse_args()
    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    # Create ConvAI2 data so we can assign personas.
    convai2_opt = opt.copy()
    convai2_opt['task'] = 'convai2:both'
    convai2_agent = RepeatLabelAgent(convai2_opt)
    convai2_world = create_task(convai2_opt, convai2_agent)

    def get_new_personas():
        """Advance the ConvAI2 world to a new episode and return the bot persona.

        Prints the human's persona lines and the ``[DONE]`` hint, and returns
        the bot's persona lines, each terminated by a newline.
        """
        # Step until an act is flagged episode_done, then take the act that
        # follows it.
        while True:
            convai2_world.parley()
            msg = convai2_world.get_acts()[0]
            if msg['episode_done']:
                convai2_world.parley()
                msg = convai2_world.get_acts()[0]
                break

        bot_lines = []
        for line in msg.get('text', '').split('\n'):
            if line.startswith("partner's persona:"):
                # The data's "partner" is the human in this chat.
                print(line.replace("partner's ", 'your '))
            if line.startswith('your persona:'):
                bot_lines.append(line + '\n')
        print("Enter [DONE] if you want a new partner at any time.")
        return ''.join(bot_lines)

    # Now run interactive mode, chatting with personas!
    cnt = 0  # exchanges completed in the current chat; 0 marks a new chat
    while True:
        if cnt == 0:
            bot_persona = get_new_personas()
            print('BOT PERSONA:')
            print(bot_persona.split('\n'))

        # Run the parts of world.parley() in turn,
        # but insert persona into user message.
        acts = world.acts
        agents = world.agents
        acts[0] = agents[0].act()

        if agents[0].episode_done():
            # Reset so the next iteration draws new personas and starts over.
            print("CHAT DONE ")
            print("\n... preparing new chat... \n")
            cnt = 0
            agents[0].episodeDone = False
            continue

        # add the persona on to the first message
        if cnt == 0:
            acts[0]['text'] = bot_persona + acts[0].get('text', 'hi')

        agents[1].observe(acts[0])
        acts[1] = agents[1].act()
        agents[0].observe(acts[1])
        world.update_counters()
        cnt += 1

        if opt.get('display_examples'):
            print("---")
            print(world.display())
if __name__ == '__main__':
    # Seed Python's RNG with a fixed value before anything else runs.
    random.seed(42)

    parser = setup_args()

    # Parameter overrides selecting the transmitter model and its checkpoint.
    transmitter_params = {
        'dict_lower': True,
        'batchsize': 1,
        'rank_candidates': False,
        'model': 'agents.transmitter.transmitter:TransformerAgent',
        'init_model_transmitter': pretrained_model_file,
    }
    parser.set_params(**transmitter_params)

    # Parse without printing; the parser is passed to interactive() as
    # print_parser instead.
    parsed_opt = parser.parse_args(print_args=False)
    interactive(parsed_opt, print_parser=parser)