Fix deleting new messages generated after conversation load

This commit is contained in:
Debanjum 2024-10-30 18:10:43 -07:00
parent cb90abc660
commit e8e6ead39f
8 changed files with 60 additions and 29 deletions

View file

@ -29,6 +29,7 @@ interface ChatBodyDataProps {
onConversationIdChange?: (conversationId: string) => void; onConversationIdChange?: (conversationId: string) => void;
setQueryToProcess: (query: string) => void; setQueryToProcess: (query: string) => void;
streamedMessages: StreamMessage[]; streamedMessages: StreamMessage[];
setStreamedMessages: (messages: StreamMessage[]) => void;
setUploadedFiles: (files: string[]) => void; setUploadedFiles: (files: string[]) => void;
isMobileWidth?: boolean; isMobileWidth?: boolean;
isLoggedIn: boolean; isLoggedIn: boolean;
@ -118,6 +119,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
setAgent={setAgentMetadata} setAgent={setAgentMetadata}
pendingMessage={processingMessage ? message : ""} pendingMessage={processingMessage ? message : ""}
incomingMessages={props.streamedMessages} incomingMessages={props.streamedMessages}
setIncomingMessages={props.setStreamedMessages}
customClassName={chatHistoryCustomClassName} customClassName={chatHistoryCustomClassName}
/> />
</div> </div>
@ -351,6 +353,7 @@ export default function Chat() {
<ChatBodyData <ChatBodyData
isLoggedIn={authenticatedData !== null} isLoggedIn={authenticatedData !== null}
streamedMessages={messages} streamedMessages={messages}
setStreamedMessages={setMessages}
chatOptionsData={chatOptionsData} chatOptionsData={chatOptionsData}
setTitle={setTitle} setTitle={setTitle}
setQueryToProcess={setQueryToProcess} setQueryToProcess={setQueryToProcess}

View file

@ -11,6 +11,11 @@ export interface RawReferenceData {
codeContext?: CodeContext; codeContext?: CodeContext;
} }
export interface MessageMetadata {
conversationId: string;
turnId: string;
}
export interface ResponseWithIntent { export interface ResponseWithIntent {
intentType: string; intentType: string;
response: string; response: string;
@ -90,6 +95,9 @@ export function processMessageChunk(
if (references.onlineContext) onlineContext = references.onlineContext; if (references.onlineContext) onlineContext = references.onlineContext;
if (references.codeContext) codeContext = references.codeContext; if (references.codeContext) codeContext = references.codeContext;
return { context, onlineContext, codeContext }; return { context, onlineContext, codeContext };
} else if (chunk.type === "metadata") {
const messageMetadata = chunk.data as MessageMetadata;
currentMessage.turnId = messageMetadata.turnId;
} else if (chunk.type === "message") { } else if (chunk.type === "message") {
const chunkData = chunk.data; const chunkData = chunk.data;
// Here, handle if the response is a JSON response with an image, but the intentType is excalidraw // Here, handle if the response is a JSON response with an image, but the intentType is excalidraw

View file

@ -34,8 +34,9 @@ interface ChatHistory {
interface ChatHistoryProps { interface ChatHistoryProps {
conversationId: string; conversationId: string;
setTitle: (title: string) => void; setTitle: (title: string) => void;
incomingMessages?: StreamMessage[];
pendingMessage?: string; pendingMessage?: string;
incomingMessages?: StreamMessage[];
setIncomingMessages?: (incomingMessages: StreamMessage[]) => void;
publicConversationSlug?: string; publicConversationSlug?: string;
setAgent: (agent: AgentData) => void; setAgent: (agent: AgentData) => void;
customClassName?: string; customClassName?: string;
@ -97,6 +98,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
const [data, setData] = useState<ChatHistoryData | null>(null); const [data, setData] = useState<ChatHistoryData | null>(null);
const [currentPage, setCurrentPage] = useState(0); const [currentPage, setCurrentPage] = useState(0);
const [hasMoreMessages, setHasMoreMessages] = useState(true); const [hasMoreMessages, setHasMoreMessages] = useState(true);
const [currentTurnId, setCurrentTurnId] = useState<string | null>(null);
const sentinelRef = useRef<HTMLDivElement | null>(null); const sentinelRef = useRef<HTMLDivElement | null>(null);
const scrollAreaRef = useRef<HTMLDivElement | null>(null); const scrollAreaRef = useRef<HTMLDivElement | null>(null);
const latestUserMessageRef = useRef<HTMLDivElement | null>(null); const latestUserMessageRef = useRef<HTMLDivElement | null>(null);
@ -177,6 +179,10 @@ export default function ChatHistory(props: ChatHistoryProps) {
if (lastMessage && !lastMessage.completed) { if (lastMessage && !lastMessage.completed) {
setIncompleteIncomingMessageIndex(props.incomingMessages.length - 1); setIncompleteIncomingMessageIndex(props.incomingMessages.length - 1);
props.setTitle(lastMessage.rawQuery); props.setTitle(lastMessage.rawQuery);
// Store the turnId when we get it
if (lastMessage.turnId) {
setCurrentTurnId(lastMessage.turnId);
}
} }
} }
}, [props.incomingMessages]); }, [props.incomingMessages]);
@ -279,6 +285,8 @@ export default function ChatHistory(props: ChatHistoryProps) {
} }
const handleDeleteMessage = (turnId?: string) => { const handleDeleteMessage = (turnId?: string) => {
if (!turnId) return;
setData((prevData) => { setData((prevData) => {
if (!prevData || !turnId) return prevData; if (!prevData || !turnId) return prevData;
return { return {
@ -286,6 +294,13 @@ export default function ChatHistory(props: ChatHistoryProps) {
chat: prevData.chat.filter((msg) => msg.turnId !== turnId), chat: prevData.chat.filter((msg) => msg.turnId !== turnId),
}; };
}); });
// Update incoming messages if they exist
if (props.incomingMessages && props.setIncomingMessages) {
props.setIncomingMessages(
props.incomingMessages.filter((msg) => msg.turnId !== turnId),
);
}
}; };
if (!props.conversationId && !props.publicConversationSlug) { if (!props.conversationId && !props.publicConversationSlug) {
@ -341,6 +356,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
))} ))}
{props.incomingMessages && {props.incomingMessages &&
props.incomingMessages.map((message, index) => { props.incomingMessages.map((message, index) => {
const messageTurnId = message.turnId ?? currentTurnId ?? undefined;
return ( return (
<React.Fragment key={`incomingMessage${index}`}> <React.Fragment key={`incomingMessage${index}`}>
<ChatMessage <ChatMessage
@ -356,13 +372,13 @@ export default function ChatHistory(props: ChatHistoryProps) {
automationId: "", automationId: "",
images: message.images, images: message.images,
conversationId: props.conversationId, conversationId: props.conversationId,
turnId: message.turnId, turnId: messageTurnId,
onDeleteMessage: handleDeleteMessage,
}} }}
customClassName="fullHistory" customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`} borderLeftColor={`${data?.agent?.color}-500`}
onDeleteMessage={handleDeleteMessage} onDeleteMessage={handleDeleteMessage}
conversationId={props.conversationId} conversationId={props.conversationId}
turnId={messageTurnId}
/> />
{message.trainOfThought && ( {message.trainOfThought && (
<TrainOfThoughtComponent <TrainOfThoughtComponent
@ -393,10 +409,10 @@ export default function ChatHistory(props: ChatHistoryProps) {
"inferred-queries": message.inferredQueries || [], "inferred-queries": message.inferredQueries || [],
}, },
conversationId: props.conversationId, conversationId: props.conversationId,
turnId: message.turnId, turnId: messageTurnId,
onDeleteMessage: handleDeleteMessage,
}} }}
conversationId={props.conversationId} conversationId={props.conversationId}
turnId={messageTurnId}
onDeleteMessage={handleDeleteMessage} onDeleteMessage={handleDeleteMessage}
customClassName="fullHistory" customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`} borderLeftColor={`${data?.agent?.color}-500`}
@ -418,7 +434,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
by: "you", by: "you",
automationId: "", automationId: "",
conversationId: props.conversationId, conversationId: props.conversationId,
onDeleteMessage: handleDeleteMessage, turnId: undefined,
}} }}
conversationId={props.conversationId} conversationId={props.conversationId}
onDeleteMessage={handleDeleteMessage} onDeleteMessage={handleDeleteMessage}

View file

@ -149,7 +149,6 @@ export interface SingleChatMessage {
images?: string[]; images?: string[];
conversationId: string; conversationId: string;
turnId?: string; turnId?: string;
onDeleteMessage: (turnId: string) => void;
} }
export interface StreamMessage { export interface StreamMessage {
@ -249,6 +248,7 @@ interface ChatMessageProps {
agent?: AgentData; agent?: AgentData;
onDeleteMessage: (turnId?: string) => void; onDeleteMessage: (turnId?: string) => void;
conversationId: string; conversationId: string;
turnId?: string;
} }
interface TrainOfThoughtProps { interface TrainOfThoughtProps {
@ -662,6 +662,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
} }
const deleteMessage = async (message: SingleChatMessage) => { const deleteMessage = async (message: SingleChatMessage) => {
const turnId = message.turnId || props.turnId;
const response = await fetch("/api/chat/conversation/message", { const response = await fetch("/api/chat/conversation/message", {
method: "DELETE", method: "DELETE",
headers: { headers: {
@ -669,13 +670,13 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
}, },
body: JSON.stringify({ body: JSON.stringify({
conversation_id: props.conversationId, conversation_id: props.conversationId,
turn_id: message.turnId, turn_id: turnId,
}), }),
}); });
if (response.ok) { if (response.ok) {
// Update the UI after successful deletion // Update the UI after successful deletion
props.onDeleteMessage(message.turnId); props.onDeleteMessage(turnId);
} else { } else {
console.error("Failed to delete message"); console.error("Failed to delete message");
} }
@ -743,6 +744,16 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
/> />
</button> </button>
))} ))}
<button
title="Delete"
className={`${styles.deleteButton}`}
onClick={() => deleteMessage(props.chatMessage)}
>
<Trash
alt="Delete Message"
className="hsl(var(--muted-foreground)) hover:text-red-500"
/>
</button>
<button <button
title="Copy" title="Copy"
className={`${styles.copyButton}`} className={`${styles.copyButton}`}
@ -764,16 +775,6 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
/> />
)} )}
</button> </button>
<button
title="Delete"
className={`${styles.deleteButton}`}
onClick={() => deleteMessage(props.chatMessage)}
>
<Trash
alt="Delete Message"
className="hsl(var(--muted-foreground)) hover:text-red-500"
/>
</button>
{props.chatMessage.by === "khoj" && {props.chatMessage.by === "khoj" &&
(props.chatMessage.intent ? ( (props.chatMessage.intent ? (
<FeedbackButtons <FeedbackButtons

View file

@ -197,7 +197,6 @@ function ReferenceVerification(props: ReferenceVerificationProps) {
codeContext: {}, codeContext: {},
conversationId: props.conversationId, conversationId: props.conversationId,
turnId: "", turnId: "",
onDeleteMessage: (turnId?: string) => {},
}} }}
isMobileWidth={isMobileWidth} isMobileWidth={isMobileWidth}
onDeleteMessage={(turnId?: string) => {}} onDeleteMessage={(turnId?: string) => {}}
@ -633,7 +632,6 @@ export default function FactChecker() {
codeContext: {}, codeContext: {},
conversationId: conversationID, conversationId: conversationID,
turnId: "", turnId: "",
onDeleteMessage: (turnId?: string) => {},
}} }}
conversationId={conversationID} conversationId={conversationID}
onDeleteMessage={(turnId?: string) => {}} onDeleteMessage={(turnId?: string) => {}}

View file

@ -186,6 +186,7 @@ class ChatEvent(Enum):
MESSAGE = "message" MESSAGE = "message"
REFERENCES = "references" REFERENCES = "references"
STATUS = "status" STATUS = "status"
METADATA = "metadata"
def message_to_log( def message_to_log(

View file

@ -569,6 +569,7 @@ async def chat(
stream = body.stream stream = body.stream
title = body.title title = body.title
conversation_id = body.conversation_id conversation_id = body.conversation_id
turn_id = str(body.turn_id or uuid.uuid4())
city = body.city city = body.city
region = body.region region = body.region
country = body.country or get_country_name_from_timezone(body.timezone) country = body.country or get_country_name_from_timezone(body.timezone)
@ -588,7 +589,7 @@ async def chat(
nonlocal conversation_id nonlocal conversation_id
tracer: dict = { tracer: dict = {
"mid": f"{uuid.uuid4()}", "mid": turn_id,
"cid": conversation_id, "cid": conversation_id,
"uid": user.id, "uid": user.id,
"khoj_version": state.khoj_version, "khoj_version": state.khoj_version,
@ -621,7 +622,7 @@ async def chat(
if event_type == ChatEvent.MESSAGE: if event_type == ChatEvent.MESSAGE:
yield data yield data
elif event_type == ChatEvent.REFERENCES or stream: elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError as e: except asyncio.CancelledError as e:
connection_alive = False connection_alive = False
@ -665,6 +666,11 @@ async def chat(
metadata=chat_metadata, metadata=chat_metadata,
) )
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
yield result
return
conversation_commands = [get_conversation_command(query=q, any_references=True)] conversation_commands = [get_conversation_command(query=q, any_references=True)]
conversation = await ConversationAdapters.aget_conversation_by_user( conversation = await ConversationAdapters.aget_conversation_by_user(
@ -680,6 +686,9 @@ async def chat(
return return
conversation_id = conversation.id conversation_id = conversation.id
async for event in send_event(ChatEvent.METADATA, {"conversationId": str(conversation_id), "turnId": turn_id}):
yield event
agent: Agent | None = None agent: Agent | None = None
default_agent = await AgentAdapters.aget_default_agent() default_agent = await AgentAdapters.aget_default_agent()
if conversation.agent and conversation.agent != default_agent: if conversation.agent and conversation.agent != default_agent:
@ -691,17 +700,11 @@ async def chat(
agent = default_agent agent = default_agent
await is_ready_to_chat(user) await is_ready_to_chat(user)
user_name = await aget_user_name(user) user_name = await aget_user_name(user)
location = None location = None
if city or region or country or country_code: if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country, country_code=country_code) location = LocationData(city=city, region=region, country=country, country_code=country_code)
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
yield result
return
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
meta_log = conversation.conversation_log meta_log = conversation.conversation_log

View file

@ -1255,6 +1255,7 @@ class ChatRequestBody(BaseModel):
stream: Optional[bool] = False stream: Optional[bool] = False
title: Optional[str] = None title: Optional[str] = None
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
turn_id: Optional[str] = None
city: Optional[str] = None city: Optional[str] = None
region: Optional[str] = None region: Optional[str] = None
country: Optional[str] = None country: Optional[str] = None