diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index deb26105..c3d5ff37 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -29,6 +29,7 @@ interface ChatBodyDataProps { onConversationIdChange?: (conversationId: string) => void; setQueryToProcess: (query: string) => void; streamedMessages: StreamMessage[]; + setStreamedMessages: (messages: StreamMessage[]) => void; setUploadedFiles: (files: string[]) => void; isMobileWidth?: boolean; isLoggedIn: boolean; @@ -118,6 +119,7 @@ function ChatBodyData(props: ChatBodyDataProps) { setAgent={setAgentMetadata} pendingMessage={processingMessage ? message : ""} incomingMessages={props.streamedMessages} + setIncomingMessages={props.setStreamedMessages} customClassName={chatHistoryCustomClassName} /> @@ -351,6 +353,7 @@ export default function Chat() { void; - incomingMessages?: StreamMessage[]; pendingMessage?: string; + incomingMessages?: StreamMessage[]; + setIncomingMessages?: (incomingMessages: StreamMessage[]) => void; publicConversationSlug?: string; setAgent: (agent: AgentData) => void; customClassName?: string; @@ -97,6 +98,7 @@ export default function ChatHistory(props: ChatHistoryProps) { const [data, setData] = useState(null); const [currentPage, setCurrentPage] = useState(0); const [hasMoreMessages, setHasMoreMessages] = useState(true); + const [currentTurnId, setCurrentTurnId] = useState(null); const sentinelRef = useRef(null); const scrollAreaRef = useRef(null); const latestUserMessageRef = useRef(null); @@ -177,6 +179,10 @@ export default function ChatHistory(props: ChatHistoryProps) { if (lastMessage && !lastMessage.completed) { setIncompleteIncomingMessageIndex(props.incomingMessages.length - 1); props.setTitle(lastMessage.rawQuery); + // Store the turnId when we get it + if (lastMessage.turnId) { + setCurrentTurnId(lastMessage.turnId); + } } } }, [props.incomingMessages]); @@ -279,6 +285,8 @@ export default function ChatHistory(props: ChatHistoryProps) { } const handleDeleteMessage = (turnId?: string) => { + if (!turnId) return; + setData((prevData) => { if (!prevData || !turnId) return prevData; return { @@ -286,6 +294,13 @@ export default function ChatHistory(props: ChatHistoryProps) { chat: prevData.chat.filter((msg) => msg.turnId !== turnId), }; }); + + // Update incoming messages if they exist + if (props.incomingMessages && props.setIncomingMessages) { + props.setIncomingMessages( + props.incomingMessages.filter((msg) => msg.turnId !== turnId), + ); + } }; if (!props.conversationId && !props.publicConversationSlug) { @@ -341,6 +356,7 @@ export default function ChatHistory(props: ChatHistoryProps) { ))} {props.incomingMessages && props.incomingMessages.map((message, index) => { + const messageTurnId = message.turnId ?? currentTurnId ?? undefined; return ( {message.trainOfThought && ( void; } export interface StreamMessage { @@ -249,6 +248,7 @@ interface ChatMessageProps { agent?: AgentData; onDeleteMessage: (turnId?: string) => void; conversationId: string; + turnId?: string; } interface TrainOfThoughtProps { @@ -662,6 +662,7 @@ const ChatMessage = forwardRef((props, ref) => } const deleteMessage = async (message: SingleChatMessage) => { + const turnId = message.turnId || props.turnId; const response = await fetch("/api/chat/conversation/message", { method: "DELETE", headers: { @@ -669,13 +670,13 @@ const ChatMessage = forwardRef((props, ref) => }, body: JSON.stringify({ conversation_id: props.conversationId, - turn_id: message.turnId, + turn_id: turnId, }), }); if (response.ok) { // Update the UI after successful deletion - props.onDeleteMessage(message.turnId); + props.onDeleteMessage(turnId); } else { console.error("Failed to delete message"); } @@ -743,6 +744,16 @@ const ChatMessage = forwardRef((props, ref) => /> ))} + - {props.chatMessage.by === "khoj" && (props.chatMessage.intent ? ( {}, }} isMobileWidth={isMobileWidth} onDeleteMessage={(turnId?: string) => {}} @@ -633,7 +632,6 @@ export default function FactChecker() { codeContext: {}, conversationId: conversationID, turnId: "", - onDeleteMessage: (turnId?: string) => {}, }} conversationId={conversationID} onDeleteMessage={(turnId?: string) => {}} diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index dd420881..241ed783 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -186,6 +186,7 @@ class ChatEvent(Enum): MESSAGE = "message" REFERENCES = "references" STATUS = "status" + METADATA = "metadata" def message_to_log( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 92c024e9..83dd0c24 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -569,6 +569,7 @@ async def chat( stream = body.stream title = body.title conversation_id = body.conversation_id + turn_id = str(body.turn_id or uuid.uuid4()) city = body.city region = body.region country = body.country or get_country_name_from_timezone(body.timezone) @@ -588,7 +589,7 @@ async def chat( nonlocal conversation_id tracer: dict = { - "mid": f"{uuid.uuid4()}", + "mid": turn_id, "cid": conversation_id, "uid": user.id, "khoj_version": state.khoj_version, @@ -621,7 +622,7 @@ async def chat( if event_type == ChatEvent.MESSAGE: yield data - elif event_type == ChatEvent.REFERENCES or stream: + elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) except asyncio.CancelledError as e: connection_alive = False @@ -665,6 +666,11 @@ async def chat( metadata=chat_metadata, ) + if is_query_empty(q): + async for result in send_llm_response("Please ask your query to get started."): + yield result + return + conversation_commands = [get_conversation_command(query=q, any_references=True)] conversation = await ConversationAdapters.aget_conversation_by_user( @@ -680,6 +686,9 @@ async def chat( return conversation_id = conversation.id + async for event in send_event(ChatEvent.METADATA, {"conversationId": str(conversation_id), "turnId": turn_id}): + yield event + agent: Agent | None = None default_agent = await AgentAdapters.aget_default_agent() if conversation.agent and conversation.agent != default_agent: @@ -691,17 +700,11 @@ async def chat( agent = default_agent await is_ready_to_chat(user) - user_name = await aget_user_name(user) location = None if city or region or country or country_code: location = LocationData(city=city, region=region, country=country, country_code=country_code) - if is_query_empty(q): - async for result in send_llm_response("Please ask your query to get started."): - yield result - return - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") meta_log = conversation.conversation_log diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 6aa25c5e..f1b8fe2b 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1255,6 +1255,7 @@ class ChatRequestBody(BaseModel): stream: Optional[bool] = False title: Optional[str] = None conversation_id: Optional[str] = None + turn_id: Optional[str] = None city: Optional[str] = None region: Optional[str] = None country: Optional[str] = None