mirror of
https://github.com/Mintplex-Labs/anything-llm.git
synced 2025-04-17 18:18:11 +00:00
[FEAT] Fork chat to new thread (#1788)
* implement thread forking feature * rename thread based on forked message * refactor bulk message create for thread fork + bump prisma version * revert prisma version bump * add todo to bulkCreate function in workspace chats * cast user input to expected type to prevent prisma injection * refactor: update order of ops for thread fork --------- Co-authored-by: Timothy Carambat <rambat1010@gmail.com>
This commit is contained in:
parent
a87014822a
commit
8b5d9ccdb3
6 changed files with 157 additions and 1 deletions
frontend/src
components/WorkspaceChat/ChatContainer/ChatHistory
models
server
|
@ -6,6 +6,7 @@ import {
|
|||
ThumbsDown,
|
||||
ArrowsClockwise,
|
||||
Copy,
|
||||
GitMerge,
|
||||
} from "@phosphor-icons/react";
|
||||
import { Tooltip } from "react-tooltip";
|
||||
import Workspace from "@/models/workspace";
|
||||
|
@ -19,6 +20,7 @@ const Actions = ({
|
|||
slug,
|
||||
isLastMessage,
|
||||
regenerateMessage,
|
||||
forkThread,
|
||||
isEditing,
|
||||
role,
|
||||
}) => {
|
||||
|
@ -32,8 +34,14 @@ const Actions = ({
|
|||
|
||||
return (
|
||||
<div className="flex w-full justify-between items-center">
|
||||
<div className="flex justify-start items-center gap-x-4">
|
||||
<div className="flex justify-start items-center gap-x-4 group">
|
||||
<CopyMessage message={message} />
|
||||
<ForkThread
|
||||
chatId={chatId}
|
||||
forkThread={forkThread}
|
||||
isEditing={isEditing}
|
||||
role={role}
|
||||
/>
|
||||
<EditMessageAction chatId={chatId} role={role} isEditing={isEditing} />
|
||||
{isLastMessage && !isEditing && (
|
||||
<RegenerateMessage
|
||||
|
@ -150,5 +158,27 @@ function RegenerateMessage({ regenerateMessage, chatId }) {
|
|||
</div>
|
||||
);
|
||||
}
|
||||
function ForkThread({ chatId, forkThread, isEditing, role }) {
|
||||
if (!chatId || isEditing || role === "user") return null;
|
||||
return (
|
||||
<div className="mt-3 relative">
|
||||
<button
|
||||
onClick={() => forkThread(chatId)}
|
||||
data-tooltip-id="fork-thread"
|
||||
data-tooltip-content="Fork chat to new thread"
|
||||
className="border-none text-zinc-300"
|
||||
aria-label="Fork"
|
||||
>
|
||||
<GitMerge size={18} className="mb-1" weight="fill" />
|
||||
</button>
|
||||
<Tooltip
|
||||
id="fork-thread"
|
||||
place="bottom"
|
||||
delayShow={300}
|
||||
className="tooltip !text-xs"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default memo(Actions);
|
||||
|
|
|
@ -23,6 +23,7 @@ const HistoricalMessage = ({
|
|||
isLastMessage = false,
|
||||
regenerateMessage,
|
||||
saveEditedMessage,
|
||||
forkThread,
|
||||
}) => {
|
||||
const { isEditing } = useEditMessage({ chatId, role });
|
||||
const adjustTextArea = (event) => {
|
||||
|
@ -95,6 +96,7 @@ const HistoricalMessage = ({
|
|||
regenerateMessage={regenerateMessage}
|
||||
isEditing={isEditing}
|
||||
role={role}
|
||||
forkThread={forkThread}
|
||||
/>
|
||||
</div>
|
||||
{role === "assistant" && <Citations sources={sources} />}
|
||||
|
|
|
@ -9,6 +9,7 @@ import useUser from "@/hooks/useUser";
|
|||
import Chartable from "./Chartable";
|
||||
import Workspace from "@/models/workspace";
|
||||
import { useParams } from "react-router-dom";
|
||||
import paths from "@/utils/paths";
|
||||
|
||||
export default function ChatHistory({
|
||||
history = [],
|
||||
|
@ -131,6 +132,18 @@ export default function ChatHistory({
|
|||
}
|
||||
};
|
||||
|
||||
const forkThread = async (chatId) => {
|
||||
const newThreadSlug = await Workspace.forkThread(
|
||||
workspace.slug,
|
||||
threadSlug,
|
||||
chatId
|
||||
);
|
||||
window.location.href = paths.workspace.thread(
|
||||
workspace.slug,
|
||||
newThreadSlug
|
||||
);
|
||||
};
|
||||
|
||||
if (history.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col h-full md:mt-0 pb-44 md:pb-40 w-full justify-end items-center">
|
||||
|
@ -217,6 +230,7 @@ export default function ChatHistory({
|
|||
regenerateMessage={regenerateAssistantMessage}
|
||||
isLastMessage={isLastBotReply}
|
||||
saveEditedMessage={saveEditedMessage}
|
||||
forkThread={forkThread}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
|
|
|
@ -384,6 +384,22 @@ const Workspace = {
|
|||
return false;
|
||||
});
|
||||
},
|
||||
forkThread: async function (slug = "", threadSlug = null, chatId = null) {
|
||||
return await fetch(`${API_BASE}/workspace/${slug}/thread/fork`, {
|
||||
method: "POST",
|
||||
headers: baseHeaders(),
|
||||
body: JSON.stringify({ threadSlug, chatId }),
|
||||
})
|
||||
.then((res) => {
|
||||
if (!res.ok) throw new Error("Failed to fork thread.");
|
||||
return res.json();
|
||||
})
|
||||
.then((data) => data.newThreadSlug)
|
||||
.catch((e) => {
|
||||
console.error("Error forking thread:", e);
|
||||
return null;
|
||||
});
|
||||
},
|
||||
threads: WorkspaceThread,
|
||||
};
|
||||
|
||||
|
|
|
@ -31,6 +31,8 @@ const {
|
|||
fetchPfp,
|
||||
} = require("../utils/files/pfp");
|
||||
const { getTTSProvider } = require("../utils/TextToSpeech");
|
||||
const { WorkspaceThread } = require("../models/workspaceThread");
|
||||
const truncate = require("truncate");
|
||||
|
||||
function workspaceEndpoints(app) {
|
||||
if (!app) return;
|
||||
|
@ -761,6 +763,81 @@ function workspaceEndpoints(app) {
|
|||
}
|
||||
}
|
||||
);
|
||||
|
||||
app.post(
|
||||
"/workspace/:slug/thread/fork",
|
||||
[validatedRequest, flexUserRoleValid([ROLES.all]), validWorkspaceSlug],
|
||||
async (request, response) => {
|
||||
try {
|
||||
const user = await userFromSession(request, response);
|
||||
const workspace = response.locals.workspace;
|
||||
const { chatId, threadSlug } = reqBody(request);
|
||||
if (!chatId)
|
||||
return response.status(400).json({ message: "chatId is required" });
|
||||
|
||||
// Get threadId we are branching from if that request body is sent
|
||||
// and is a valid thread slug.
|
||||
const threadId = !!threadSlug
|
||||
? (
|
||||
await WorkspaceThread.get({
|
||||
slug: String(threadSlug),
|
||||
workspace_id: workspace.id,
|
||||
})
|
||||
)?.id ?? null
|
||||
: null;
|
||||
const chatsToFork = await WorkspaceChats.where(
|
||||
{
|
||||
workspaceId: workspace.id,
|
||||
user_id: user?.id,
|
||||
include: true, // only duplicate visible chats
|
||||
thread_id: threadId,
|
||||
id: { lte: Number(chatId) },
|
||||
},
|
||||
null,
|
||||
{ id: "asc" }
|
||||
);
|
||||
|
||||
const { thread: newThread, message: threadError } =
|
||||
await WorkspaceThread.new(workspace, user?.id);
|
||||
if (threadError)
|
||||
return response.status(500).json({ error: threadError });
|
||||
|
||||
let lastMessageText = "";
|
||||
const chatsData = chatsToFork.map((chat) => {
|
||||
const chatResponse = safeJsonParse(chat.response, {});
|
||||
if (chatResponse?.text) lastMessageText = chatResponse.text;
|
||||
|
||||
return {
|
||||
workspaceId: workspace.id,
|
||||
prompt: chat.prompt,
|
||||
response: JSON.stringify(chatResponse),
|
||||
user_id: user?.id,
|
||||
thread_id: newThread.id,
|
||||
};
|
||||
});
|
||||
await WorkspaceChats.bulkCreate(chatsData);
|
||||
await WorkspaceThread.update(newThread, {
|
||||
name: !!lastMessageText
|
||||
? truncate(lastMessageText, 22)
|
||||
: "Forked Thread",
|
||||
});
|
||||
|
||||
await Telemetry.sendTelemetry("thread_forked");
|
||||
await EventLogs.logEvent(
|
||||
"thread_forked",
|
||||
{
|
||||
workspaceName: workspace?.name || "Unknown Workspace",
|
||||
threadName: newThread.name,
|
||||
},
|
||||
user?.id
|
||||
);
|
||||
response.status(200).json({ newThreadSlug: newThread.slug });
|
||||
} catch (e) {
|
||||
console.log(e.message, e);
|
||||
response.status(500).json({ message: "Internal server error" });
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
module.exports = { workspaceEndpoints };
|
||||
|
|
|
@ -240,6 +240,23 @@ const WorkspaceChats = {
|
|||
return false;
|
||||
}
|
||||
},
|
||||
bulkCreate: async function (chatsData) {
|
||||
// TODO: Replace with createMany when we update prisma to latest version
|
||||
// The version of prisma that we are currently using does not support createMany with SQLite
|
||||
try {
|
||||
const createdChats = [];
|
||||
for (const chatData of chatsData) {
|
||||
const chat = await prisma.workspace_chats.create({
|
||||
data: chatData,
|
||||
});
|
||||
createdChats.push(chat);
|
||||
}
|
||||
return { chats: createdChats, message: null };
|
||||
} catch (error) {
|
||||
console.error(error.message);
|
||||
return { chats: null, message: error.message };
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
module.exports = { WorkspaceChats };
|
||||
|
|
Loading…
Add table
Reference in a new issue