@@ -994,7 +994,8 @@ class ChatViewResponseHandler : public BaseResponseHandler {
994
994
return onBufferResponse (removeLeadingWhitespace (respStr), 0 );
995
995
}
996
996
997
- bool getStopGenerating () const override { return m_cllm->m_stopGenerating ; }
997
+ bool getStopGenerating () const override
998
+ { return m_cllm->m_stopGenerating ; }
998
999
999
1000
private:
1000
1001
ChatLLM *m_cllm;
@@ -1162,13 +1163,10 @@ void ChatLLM::reloadModel()
1162
1163
loadModel (m);
1163
1164
}
1164
1165
1165
- class NameResponseHandler : public BaseResponseHandler {
1166
- private:
1167
- // max length of chat names, in words
1168
- static constexpr qsizetype MAX_WORDS = 3 ;
1169
-
1166
+ // This class throws discards the text within thinking tags, for use with chat names and follow-up questions.
1167
+ class SimpleResponseHandler : public BaseResponseHandler {
1170
1168
public:
1171
- NameResponseHandler (ChatLLM *cllm)
1169
+ SimpleResponseHandler (ChatLLM *cllm)
1172
1170
: m_cllm(cllm) {}
1173
1171
1174
1172
void onSplitIntoTwo (const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) override
@@ -1178,15 +1176,40 @@ class NameResponseHandler : public BaseResponseHandler {
1178
1176
{ /* no-op */ }
1179
1177
1180
1178
void onOldResponseChunk (const QByteArray &chunk) override
1181
- {
1182
- m_response.append (chunk);
1183
- }
1179
+ { m_response.append (chunk); }
1184
1180
1185
1181
bool onBufferResponse (const QString &response, int bufferIdx) override
1186
1182
{
1187
1183
if (bufferIdx == 1 )
1188
1184
return true ; // ignore "think" content
1185
+ return onSimpleResponse (response);
1186
+ }
1187
+
1188
+ bool onRegularResponse () override
1189
+ { return onBufferResponse (QString::fromUtf8 (m_response), 0 ); }
1190
+
1191
+ bool getStopGenerating () const override
1192
+ { return m_cllm->m_stopGenerating ; }
1189
1193
1194
+ protected:
1195
+ virtual bool onSimpleResponse (const QString &response) = 0;
1196
+
1197
+ protected:
1198
+ ChatLLM *m_cllm;
1199
+ QByteArray m_response;
1200
+ };
1201
+
1202
+ class NameResponseHandler : public SimpleResponseHandler {
1203
+ private:
1204
+ // max length of chat names, in words
1205
+ static constexpr qsizetype MAX_WORDS = 3 ;
1206
+
1207
+ public:
1208
+ using SimpleResponseHandler::SimpleResponseHandler;
1209
+
1210
+ protected:
1211
+ bool onSimpleResponse (const QString &response) override
1212
+ {
1190
1213
QTextStream stream (const_cast <QString *>(&response), QIODeviceBase::ReadOnly);
1191
1214
QStringList words;
1192
1215
while (!stream.atEnd () && words.size () < MAX_WORDS) {
@@ -1198,17 +1221,6 @@ class NameResponseHandler : public BaseResponseHandler {
1198
1221
emit m_cllm->generatedNameChanged (words.join (u' ' ));
1199
1222
return words.size () < MAX_WORDS || stream.atEnd ();
1200
1223
}
1201
-
1202
- bool onRegularResponse () override
1203
- {
1204
- return onBufferResponse (QString::fromUtf8 (m_response), 0 );
1205
- }
1206
-
1207
- bool getStopGenerating () const override { return m_cllm->m_stopGenerating ; }
1208
-
1209
- private:
1210
- ChatLLM *m_cllm;
1211
- QByteArray m_response;
1212
1224
};
1213
1225
1214
1226
void ChatLLM::generateName ()
@@ -1247,13 +1259,43 @@ void ChatLLM::handleChatIdChanged(const QString &id)
1247
1259
m_llmThread.setObjectName (id);
1248
1260
}
1249
1261
1250
- void ChatLLM::generateQuestions (qint64 elapsed)
1251
- {
1262
+ class QuestionResponseHandler : public SimpleResponseHandler {
1263
+ public:
1264
+ using SimpleResponseHandler::SimpleResponseHandler;
1265
+
1266
+ protected:
1267
+ bool onSimpleResponse (const QString &response) override
1268
+ {
1269
+ auto responseUtf8Bytes = response.toUtf8 ().slice (m_offset);
1270
+ auto responseUtf8 = std::string (responseUtf8Bytes.begin (), responseUtf8Bytes.end ());
1271
+ // extract all questions from response
1272
+ ptrdiff_t lastMatchEnd = -1 ;
1273
+ auto it = std::sregex_iterator (responseUtf8.begin (), responseUtf8.end (), s_reQuestion);
1274
+ auto end = std::sregex_iterator ();
1275
+ for (; it != end; ++it) {
1276
+ auto pos = it->position ();
1277
+ auto len = it->length ();
1278
+ lastMatchEnd = pos + len;
1279
+ emit m_cllm->generatedQuestionFinished (QString::fromUtf8 (&responseUtf8[pos], len));
1280
+ }
1281
+
1282
+ // remove processed input from buffer
1283
+ if (lastMatchEnd != -1 )
1284
+ m_offset += lastMatchEnd;
1285
+ return true ;
1286
+ }
1287
+
1288
+ private:
1252
1289
// FIXME: This only works with response by the model in english which is not ideal for a multi-language
1253
1290
// model.
1254
1291
// match whole question sentences
1255
- static const std::regex reQuestion ( R"( \b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)" ) ;
1292
+ static inline const std::regex s_reQuestion { R"( \b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)" } ;
1256
1293
1294
+ qsizetype m_offset = 0 ;
1295
+ };
1296
+
1297
+ void ChatLLM::generateQuestions (qint64 elapsed)
1298
+ {
1257
1299
Q_ASSERT (isModelLoaded ());
1258
1300
if (!isModelLoaded ()) {
1259
1301
emit responseStopped (elapsed);
@@ -1271,39 +1313,17 @@ void ChatLLM::generateQuestions(qint64 elapsed)
1271
1313
1272
1314
emit generatingQuestions ();
1273
1315
1274
- std::string response; // raw UTF-8
1275
-
1276
- auto handleResponse = [this , &response](LLModel::Token token, std::string_view piece) -> bool {
1277
- Q_UNUSED (token)
1278
-
1279
- // add token to buffer
1280
- response.append (piece);
1281
-
1282
- // extract all questions from response
1283
- ptrdiff_t lastMatchEnd = -1 ;
1284
- auto it = std::sregex_iterator (response.begin (), response.end (), reQuestion);
1285
- auto end = std::sregex_iterator ();
1286
- for (; it != end; ++it) {
1287
- auto pos = it->position ();
1288
- auto len = it->length ();
1289
- lastMatchEnd = pos + len;
1290
- emit generatedQuestionFinished (QString::fromUtf8 (&response[pos], len));
1291
- }
1292
-
1293
- // remove processed input from buffer
1294
- if (lastMatchEnd != -1 )
1295
- response.erase (0 , lastMatchEnd);
1296
- return true ;
1297
- };
1316
+ QuestionResponseHandler respHandler (this );
1298
1317
1299
1318
QElapsedTimer totalTime;
1300
1319
totalTime.start ();
1301
1320
try {
1302
- m_llModelInfo.model ->prompt (
1303
- applyJinjaTemplate (forkConversation (suggestedFollowUpPrompt)),
1304
- [this ](auto &&...) { return !m_stopGenerating; },
1305
- handleResponse,
1306
- promptContextFromSettings (m_modelInfo)
1321
+ promptModelWithTools (
1322
+ m_llModelInfo.model .get (),
1323
+ /* promptCallback*/ [this ](auto &&...) { return !m_stopGenerating; },
1324
+ respHandler, promptContextFromSettings (m_modelInfo),
1325
+ applyJinjaTemplate (forkConversation (suggestedFollowUpPrompt)).c_str (),
1326
+ { ToolCallConstants::ThinkTagName }
1307
1327
);
1308
1328
} catch (const std::exception &e) {
1309
1329
qWarning () << " ChatLLM failed to generate follow-up questions:" << e.what ();
0 commit comments