[FEAT] Fork chat to new thread ()

* implement thread forking feature

* rename thread based on forked message

* refactor bulk message create for thread fork + bump prisma version

* revert prisma version bump

* add todo to bulkCreate function in workspace chats

* cast user input to expected type to prevent prisma injection

* refactor: update order of ops for thread fork

---------

Co-authored-by: Timothy Carambat <rambat1010@gmail.com>
This commit is contained in:
Sean Hatfield 2024-07-03 14:44:35 -07:00 committed by GitHub
parent a87014822a
commit 8b5d9ccdb3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 157 additions and 1 deletions
frontend/src
components/WorkspaceChat/ChatContainer/ChatHistory
HistoricalMessage
index.jsx
models
server

View file

@ -6,6 +6,7 @@ import {
ThumbsDown,
ArrowsClockwise,
Copy,
GitMerge,
} from "@phosphor-icons/react";
import { Tooltip } from "react-tooltip";
import Workspace from "@/models/workspace";
@ -19,6 +20,7 @@ const Actions = ({
slug,
isLastMessage,
regenerateMessage,
forkThread,
isEditing,
role,
}) => {
@ -32,8 +34,14 @@ const Actions = ({
return (
<div className="flex w-full justify-between items-center">
<div className="flex justify-start items-center gap-x-4">
<div className="flex justify-start items-center gap-x-4 group">
<CopyMessage message={message} />
<ForkThread
chatId={chatId}
forkThread={forkThread}
isEditing={isEditing}
role={role}
/>
<EditMessageAction chatId={chatId} role={role} isEditing={isEditing} />
{isLastMessage && !isEditing && (
<RegenerateMessage
@ -150,5 +158,27 @@ function RegenerateMessage({ regenerateMessage, chatId }) {
</div>
);
}
function ForkThread({ chatId, forkThread, isEditing, role }) {
if (!chatId || isEditing || role === "user") return null;
return (
<div className="mt-3 relative">
<button
onClick={() => forkThread(chatId)}
data-tooltip-id="fork-thread"
data-tooltip-content="Fork chat to new thread"
className="border-none text-zinc-300"
aria-label="Fork"
>
<GitMerge size={18} className="mb-1" weight="fill" />
</button>
<Tooltip
id="fork-thread"
place="bottom"
delayShow={300}
className="tooltip !text-xs"
/>
</div>
);
}
export default memo(Actions);

View file

@ -23,6 +23,7 @@ const HistoricalMessage = ({
isLastMessage = false,
regenerateMessage,
saveEditedMessage,
forkThread,
}) => {
const { isEditing } = useEditMessage({ chatId, role });
const adjustTextArea = (event) => {
@ -95,6 +96,7 @@ const HistoricalMessage = ({
regenerateMessage={regenerateMessage}
isEditing={isEditing}
role={role}
forkThread={forkThread}
/>
</div>
{role === "assistant" && <Citations sources={sources} />}

View file

@ -9,6 +9,7 @@ import useUser from "@/hooks/useUser";
import Chartable from "./Chartable";
import Workspace from "@/models/workspace";
import { useParams } from "react-router-dom";
import paths from "@/utils/paths";
export default function ChatHistory({
history = [],
@ -131,6 +132,18 @@ export default function ChatHistory({
}
};
const forkThread = async (chatId) => {
const newThreadSlug = await Workspace.forkThread(
workspace.slug,
threadSlug,
chatId
);
window.location.href = paths.workspace.thread(
workspace.slug,
newThreadSlug
);
};
if (history.length === 0) {
return (
<div className="flex flex-col h-full md:mt-0 pb-44 md:pb-40 w-full justify-end items-center">
@ -217,6 +230,7 @@ export default function ChatHistory({
regenerateMessage={regenerateAssistantMessage}
isLastMessage={isLastBotReply}
saveEditedMessage={saveEditedMessage}
forkThread={forkThread}
/>
);
})}

View file

@ -384,6 +384,22 @@ const Workspace = {
return false;
});
},
forkThread: async function (slug = "", threadSlug = null, chatId = null) {
return await fetch(`${API_BASE}/workspace/${slug}/thread/fork`, {
method: "POST",
headers: baseHeaders(),
body: JSON.stringify({ threadSlug, chatId }),
})
.then((res) => {
if (!res.ok) throw new Error("Failed to fork thread.");
return res.json();
})
.then((data) => data.newThreadSlug)
.catch((e) => {
console.error("Error forking thread:", e);
return null;
});
},
threads: WorkspaceThread,
};

View file

@ -31,6 +31,8 @@ const {
fetchPfp,
} = require("../utils/files/pfp");
const { getTTSProvider } = require("../utils/TextToSpeech");
const { WorkspaceThread } = require("../models/workspaceThread");
const truncate = require("truncate");
function workspaceEndpoints(app) {
if (!app) return;
@ -761,6 +763,81 @@ function workspaceEndpoints(app) {
}
}
);
app.post(
"/workspace/:slug/thread/fork",
[validatedRequest, flexUserRoleValid([ROLES.all]), validWorkspaceSlug],
async (request, response) => {
try {
const user = await userFromSession(request, response);
const workspace = response.locals.workspace;
const { chatId, threadSlug } = reqBody(request);
if (!chatId)
return response.status(400).json({ message: "chatId is required" });
// Get threadId we are branching from if that request body is sent
// and is a valid thread slug.
const threadId = !!threadSlug
? (
await WorkspaceThread.get({
slug: String(threadSlug),
workspace_id: workspace.id,
})
)?.id ?? null
: null;
const chatsToFork = await WorkspaceChats.where(
{
workspaceId: workspace.id,
user_id: user?.id,
include: true, // only duplicate visible chats
thread_id: threadId,
id: { lte: Number(chatId) },
},
null,
{ id: "asc" }
);
const { thread: newThread, message: threadError } =
await WorkspaceThread.new(workspace, user?.id);
if (threadError)
return response.status(500).json({ error: threadError });
let lastMessageText = "";
const chatsData = chatsToFork.map((chat) => {
const chatResponse = safeJsonParse(chat.response, {});
if (chatResponse?.text) lastMessageText = chatResponse.text;
return {
workspaceId: workspace.id,
prompt: chat.prompt,
response: JSON.stringify(chatResponse),
user_id: user?.id,
thread_id: newThread.id,
};
});
await WorkspaceChats.bulkCreate(chatsData);
await WorkspaceThread.update(newThread, {
name: !!lastMessageText
? truncate(lastMessageText, 22)
: "Forked Thread",
});
await Telemetry.sendTelemetry("thread_forked");
await EventLogs.logEvent(
"thread_forked",
{
workspaceName: workspace?.name || "Unknown Workspace",
threadName: newThread.name,
},
user?.id
);
response.status(200).json({ newThreadSlug: newThread.slug });
} catch (e) {
console.log(e.message, e);
response.status(500).json({ message: "Internal server error" });
}
}
);
}
module.exports = { workspaceEndpoints };

View file

@ -240,6 +240,23 @@ const WorkspaceChats = {
return false;
}
},
bulkCreate: async function (chatsData) {
// TODO: Replace with createMany when we update prisma to latest version
// The version of prisma that we are currently using does not support createMany with SQLite
try {
const createdChats = [];
for (const chatData of chatsData) {
const chat = await prisma.workspace_chats.create({
data: chatData,
});
createdChats.push(chat);
}
return { chats: createdChats, message: null };
} catch (error) {
console.error(error.message);
return { chats: null, message: error.message };
}
},
};
module.exports = { WorkspaceChats };