Skip to content

Commit 9152db0

Browse files
committed
fix follow-up generation too
Signed-off-by: Jared Van Bortel <[email protected]>
1 parent 3923796 commit 9152db0

File tree

2 files changed

+74
-54
lines changed

2 files changed

+74
-54
lines changed

Diff for: gpt4all-chat/src/chatllm.cpp

+73-53
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,8 @@ class ChatViewResponseHandler : public BaseResponseHandler {
994994
return onBufferResponse(removeLeadingWhitespace(respStr), 0);
995995
}
996996

997-
bool getStopGenerating() const override { return m_cllm->m_stopGenerating; }
997+
bool getStopGenerating() const override
998+
{ return m_cllm->m_stopGenerating; }
998999

9991000
private:
10001001
ChatLLM *m_cllm;
@@ -1162,13 +1163,10 @@ void ChatLLM::reloadModel()
11621163
loadModel(m);
11631164
}
11641165

1165-
class NameResponseHandler : public BaseResponseHandler {
1166-
private:
1167-
// max length of chat names, in words
1168-
static constexpr qsizetype MAX_WORDS = 3;
1169-
1166+
// This class discards the text within thinking tags, for use with chat names and follow-up questions.
1167+
class SimpleResponseHandler : public BaseResponseHandler {
11701168
public:
1171-
NameResponseHandler(ChatLLM *cllm)
1169+
SimpleResponseHandler(ChatLLM *cllm)
11721170
: m_cllm(cllm) {}
11731171

11741172
void onSplitIntoTwo(const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) override
@@ -1178,15 +1176,40 @@ class NameResponseHandler : public BaseResponseHandler {
11781176
{ /* no-op */ }
11791177

11801178
void onOldResponseChunk(const QByteArray &chunk) override
1181-
{
1182-
m_response.append(chunk);
1183-
}
1179+
{ m_response.append(chunk); }
11841180

11851181
bool onBufferResponse(const QString &response, int bufferIdx) override
11861182
{
11871183
if (bufferIdx == 1)
11881184
return true; // ignore "think" content
1185+
return onSimpleResponse(response);
1186+
}
1187+
1188+
bool onRegularResponse() override
1189+
{ return onBufferResponse(QString::fromUtf8(m_response), 0); }
1190+
1191+
bool getStopGenerating() const override
1192+
{ return m_cllm->m_stopGenerating; }
11891193

1194+
protected:
1195+
virtual bool onSimpleResponse(const QString &response) = 0;
1196+
1197+
protected:
1198+
ChatLLM *m_cllm;
1199+
QByteArray m_response;
1200+
};
1201+
1202+
class NameResponseHandler : public SimpleResponseHandler {
1203+
private:
1204+
// max length of chat names, in words
1205+
static constexpr qsizetype MAX_WORDS = 3;
1206+
1207+
public:
1208+
using SimpleResponseHandler::SimpleResponseHandler;
1209+
1210+
protected:
1211+
bool onSimpleResponse(const QString &response) override
1212+
{
11901213
QTextStream stream(const_cast<QString *>(&response), QIODeviceBase::ReadOnly);
11911214
QStringList words;
11921215
while (!stream.atEnd() && words.size() < MAX_WORDS) {
@@ -1198,17 +1221,6 @@ class NameResponseHandler : public BaseResponseHandler {
11981221
emit m_cllm->generatedNameChanged(words.join(u' '));
11991222
return words.size() < MAX_WORDS || stream.atEnd();
12001223
}
1201-
1202-
bool onRegularResponse() override
1203-
{
1204-
return onBufferResponse(QString::fromUtf8(m_response), 0);
1205-
}
1206-
1207-
bool getStopGenerating() const override { return m_cllm->m_stopGenerating; }
1208-
1209-
private:
1210-
ChatLLM *m_cllm;
1211-
QByteArray m_response;
12121224
};
12131225

12141226
void ChatLLM::generateName()
@@ -1247,13 +1259,43 @@ void ChatLLM::handleChatIdChanged(const QString &id)
12471259
m_llmThread.setObjectName(id);
12481260
}
12491261

1250-
void ChatLLM::generateQuestions(qint64 elapsed)
1251-
{
1262+
class QuestionResponseHandler : public SimpleResponseHandler {
1263+
public:
1264+
using SimpleResponseHandler::SimpleResponseHandler;
1265+
1266+
protected:
1267+
bool onSimpleResponse(const QString &response) override
1268+
{
1269+
auto responseUtf8Bytes = response.toUtf8().slice(m_offset);
1270+
auto responseUtf8 = std::string(responseUtf8Bytes.begin(), responseUtf8Bytes.end());
1271+
// extract all questions from response
1272+
ptrdiff_t lastMatchEnd = -1;
1273+
auto it = std::sregex_iterator(responseUtf8.begin(), responseUtf8.end(), s_reQuestion);
1274+
auto end = std::sregex_iterator();
1275+
for (; it != end; ++it) {
1276+
auto pos = it->position();
1277+
auto len = it->length();
1278+
lastMatchEnd = pos + len;
1279+
emit m_cllm->generatedQuestionFinished(QString::fromUtf8(&responseUtf8[pos], len));
1280+
}
1281+
1282+
// remove processed input from buffer
1283+
if (lastMatchEnd != -1)
1284+
m_offset += lastMatchEnd;
1285+
return true;
1286+
}
1287+
1288+
private:
12521289
// FIXME: This only works with response by the model in english which is not ideal for a multi-language
12531290
// model.
12541291
// match whole question sentences
1255-
static const std::regex reQuestion(R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)");
1292+
static inline const std::regex s_reQuestion { R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)" };
12561293

1294+
qsizetype m_offset = 0;
1295+
};
1296+
1297+
void ChatLLM::generateQuestions(qint64 elapsed)
1298+
{
12571299
Q_ASSERT(isModelLoaded());
12581300
if (!isModelLoaded()) {
12591301
emit responseStopped(elapsed);
@@ -1271,39 +1313,17 @@ void ChatLLM::generateQuestions(qint64 elapsed)
12711313

12721314
emit generatingQuestions();
12731315

1274-
std::string response; // raw UTF-8
1275-
1276-
auto handleResponse = [this, &response](LLModel::Token token, std::string_view piece) -> bool {
1277-
Q_UNUSED(token)
1278-
1279-
// add token to buffer
1280-
response.append(piece);
1281-
1282-
// extract all questions from response
1283-
ptrdiff_t lastMatchEnd = -1;
1284-
auto it = std::sregex_iterator(response.begin(), response.end(), reQuestion);
1285-
auto end = std::sregex_iterator();
1286-
for (; it != end; ++it) {
1287-
auto pos = it->position();
1288-
auto len = it->length();
1289-
lastMatchEnd = pos + len;
1290-
emit generatedQuestionFinished(QString::fromUtf8(&response[pos], len));
1291-
}
1292-
1293-
// remove processed input from buffer
1294-
if (lastMatchEnd != -1)
1295-
response.erase(0, lastMatchEnd);
1296-
return true;
1297-
};
1316+
QuestionResponseHandler respHandler(this);
12981317

12991318
QElapsedTimer totalTime;
13001319
totalTime.start();
13011320
try {
1302-
m_llModelInfo.model->prompt(
1303-
applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)),
1304-
[this](auto &&...) { return !m_stopGenerating; },
1305-
handleResponse,
1306-
promptContextFromSettings(m_modelInfo)
1321+
promptModelWithTools(
1322+
m_llModelInfo.model.get(),
1323+
/*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
1324+
respHandler, promptContextFromSettings(m_modelInfo),
1325+
applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)).c_str(),
1326+
{ ToolCallConstants::ThinkTagName }
13071327
);
13081328
} catch (const std::exception &e) {
13091329
qWarning() << "ChatLLM failed to generate follow-up questions:" << e.what();

Diff for: gpt4all-chat/src/chatllm.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ public Q_SLOTS:
287287
bool m_forceMetal;
288288
bool m_reloadingToChangeVariant;
289289
friend class ChatViewResponseHandler;
290-
friend class NameResponseHandler;
290+
friend class SimpleResponseHandler;
291291
};
292292

293293
#endif // CHATLLM_H

0 commit comments

Comments
 (0)