mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Persist the train of thought in the conversation history
This commit is contained in:
parent
9e8ac7f89e
commit
a121d67b10
5 changed files with 117 additions and 50 deletions
|
@ -13,13 +13,14 @@ import { ScrollArea } from "@/components/ui/scroll-area";
|
|||
|
||||
import { InlineLoading } from "../loading/loading";
|
||||
|
||||
import { Lightbulb, ArrowDown } from "@phosphor-icons/react";
|
||||
import { Lightbulb, ArrowDown, XCircle } from "@phosphor-icons/react";
|
||||
|
||||
import AgentProfileCard from "../profileCard/profileCard";
|
||||
import { getIconFromIconName } from "@/app/common/iconUtils";
|
||||
import { AgentData } from "@/app/agents/page";
|
||||
import React from "react";
|
||||
import { useIsMobileWidth } from "@/app/common/utils";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
interface ChatResponse {
|
||||
status: string;
|
||||
|
@ -40,26 +41,51 @@ interface ChatHistoryProps {
|
|||
customClassName?: string;
|
||||
}
|
||||
|
||||
function constructTrainOfThought(
|
||||
trainOfThought: string[],
|
||||
lastMessage: boolean,
|
||||
agentColor: string,
|
||||
key: string,
|
||||
completed: boolean = false,
|
||||
) {
|
||||
const lastIndex = trainOfThought.length - 1;
|
||||
return (
|
||||
<div className={`${styles.trainOfThought} shadow-sm`} key={key}>
|
||||
{!completed && <InlineLoading className="float-right" />}
|
||||
interface TrainOfThoughtComponentProps {
|
||||
trainOfThought: string[];
|
||||
lastMessage: boolean;
|
||||
agentColor: string;
|
||||
key: string;
|
||||
completed?: boolean;
|
||||
}
|
||||
|
||||
{trainOfThought.map((train, index) => (
|
||||
<TrainOfThought
|
||||
key={`train-${index}`}
|
||||
message={train}
|
||||
primary={index === lastIndex && lastMessage && !completed}
|
||||
agentColor={agentColor}
|
||||
/>
|
||||
))}
|
||||
function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
|
||||
const lastIndex = props.trainOfThought.length - 1;
|
||||
const [collapsed, setCollapsed] = useState(props.completed);
|
||||
|
||||
return (
|
||||
<div className={`${!collapsed ? styles.trainOfThought : ""} shadow-sm`} key={props.key}>
|
||||
{!props.completed && <InlineLoading className="float-right" />}
|
||||
{collapsed ? (
|
||||
<Button
|
||||
className="w-fit text-left justify-start content-start text-xs"
|
||||
onClick={() => setCollapsed(false)}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
What was my train of thought?
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
className="w-fit text-left justify-start content-start text-xs"
|
||||
onClick={() => setCollapsed(true)}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
<XCircle size={16} className="mr-1" />
|
||||
Close
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{!collapsed &&
|
||||
props.trainOfThought.map((train, index) => (
|
||||
<TrainOfThought
|
||||
key={`train-${index}`}
|
||||
message={train}
|
||||
primary={index === lastIndex && props.lastMessage && !props.completed}
|
||||
agentColor={props.agentColor}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
@ -265,25 +291,39 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
{data &&
|
||||
data.chat &&
|
||||
data.chat.map((chatMessage, index) => (
|
||||
<ChatMessage
|
||||
key={`${index}fullHistory`}
|
||||
ref={
|
||||
// attach ref to the second last message to handle scroll on page load
|
||||
index === data.chat.length - 2
|
||||
? latestUserMessageRef
|
||||
: // attach ref to the newest fetched message to handle scroll on fetch
|
||||
// note: stabilize index selection against last page having less messages than fetchMessageCount
|
||||
index ===
|
||||
data.chat.length - (currentPage - 1) * fetchMessageCount
|
||||
? latestFetchedMessageRef
|
||||
: null
|
||||
}
|
||||
isMobileWidth={isMobileWidth}
|
||||
chatMessage={chatMessage}
|
||||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
isLastMessage={index === data.chat.length - 1}
|
||||
/>
|
||||
<>
|
||||
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
|
||||
<TrainOfThoughtComponent
|
||||
trainOfThought={chatMessage.trainOfThought?.map(
|
||||
(train) => train.data,
|
||||
)}
|
||||
lastMessage={false}
|
||||
agentColor={data?.agent?.color || "orange"}
|
||||
key={`${index}trainOfThought`}
|
||||
completed={true}
|
||||
/>
|
||||
)}
|
||||
<ChatMessage
|
||||
key={`${index}fullHistory`}
|
||||
ref={
|
||||
// attach ref to the second last message to handle scroll on page load
|
||||
index === data.chat.length - 2
|
||||
? latestUserMessageRef
|
||||
: // attach ref to the newest fetched message to handle scroll on fetch
|
||||
// note: stabilize index selection against last page having less messages than fetchMessageCount
|
||||
index ===
|
||||
data.chat.length -
|
||||
(currentPage - 1) * fetchMessageCount
|
||||
? latestFetchedMessageRef
|
||||
: null
|
||||
}
|
||||
isMobileWidth={isMobileWidth}
|
||||
chatMessage={chatMessage}
|
||||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
isLastMessage={index === data.chat.length - 1}
|
||||
/>
|
||||
</>
|
||||
))}
|
||||
{props.incomingMessages &&
|
||||
props.incomingMessages.map((message, index) => {
|
||||
|
@ -305,14 +345,15 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
/>
|
||||
{message.trainOfThought &&
|
||||
constructTrainOfThought(
|
||||
message.trainOfThought,
|
||||
index === incompleteIncomingMessageIndex,
|
||||
data?.agent?.color || "orange",
|
||||
`${index}trainOfThought`,
|
||||
message.completed,
|
||||
)}
|
||||
{message.trainOfThought && (
|
||||
<TrainOfThoughtComponent
|
||||
trainOfThought={message.trainOfThought}
|
||||
lastMessage={index === incompleteIncomingMessageIndex}
|
||||
agentColor={data?.agent?.color || "orange"}
|
||||
key={`${index}trainOfThought`}
|
||||
completed={message.completed}
|
||||
/>
|
||||
)}
|
||||
<ChatMessage
|
||||
key={`${index}incoming`}
|
||||
isMobileWidth={isMobileWidth}
|
||||
|
|
|
@ -128,6 +128,11 @@ interface Intent {
|
|||
"inferred-queries": string[];
|
||||
}
|
||||
|
||||
interface TrainOfThoughtObject {
|
||||
type: string;
|
||||
data: string;
|
||||
}
|
||||
|
||||
export interface SingleChatMessage {
|
||||
automationId: string;
|
||||
by: string;
|
||||
|
@ -136,6 +141,7 @@ export interface SingleChatMessage {
|
|||
context: Context[];
|
||||
onlineContext: OnlineContext;
|
||||
codeContext: CodeContext;
|
||||
trainOfThought?: TrainOfThoughtObject[];
|
||||
rawQuery?: string;
|
||||
intent?: Intent;
|
||||
agent?: AgentData;
|
||||
|
|
|
@ -146,7 +146,12 @@ class ChatEvent(Enum):
|
|||
|
||||
|
||||
def message_to_log(
|
||||
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
|
||||
user_message,
|
||||
chat_response,
|
||||
user_message_metadata={},
|
||||
khoj_message_metadata={},
|
||||
conversation_log=[],
|
||||
train_of_thought=[],
|
||||
):
|
||||
"""Create json logs from messages, metadata for conversation log"""
|
||||
default_khoj_message_metadata = {
|
||||
|
@ -182,6 +187,7 @@ def save_to_conversation_log(
|
|||
automation_id: str = None,
|
||||
query_images: List[str] = None,
|
||||
tracer: Dict[str, Any] = {},
|
||||
train_of_thought: List[Any] = [],
|
||||
):
|
||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
updated_conversation = message_to_log(
|
||||
|
@ -197,8 +203,10 @@ def save_to_conversation_log(
|
|||
"onlineContext": online_results,
|
||||
"codeContext": code_results,
|
||||
"automationId": automation_id,
|
||||
"trainOfThought": train_of_thought,
|
||||
},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
ConversationAdapters.save_conversation(
|
||||
user,
|
||||
|
|
|
@ -570,7 +570,9 @@ async def chat(
|
|||
user: KhojUser = request.user.object
|
||||
event_delimiter = "␃🔚␗"
|
||||
q = unquote(q)
|
||||
train_of_thought = []
|
||||
nonlocal conversation_id
|
||||
|
||||
tracer: dict = {
|
||||
"mid": f"{uuid.uuid4()}",
|
||||
"cid": conversation_id,
|
||||
|
@ -590,7 +592,7 @@ async def chat(
|
|||
uploaded_images.append(uploaded_image)
|
||||
|
||||
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||
nonlocal connection_alive, ttft
|
||||
nonlocal connection_alive, ttft, train_of_thought
|
||||
if not connection_alive or await request.is_disconnected():
|
||||
connection_alive = False
|
||||
logger.warning(f"User {user} disconnected from {common.client} client")
|
||||
|
@ -598,8 +600,11 @@ async def chat(
|
|||
try:
|
||||
if event_type == ChatEvent.END_LLM_RESPONSE:
|
||||
collect_telemetry()
|
||||
if event_type == ChatEvent.START_LLM_RESPONSE:
|
||||
elif event_type == ChatEvent.START_LLM_RESPONSE:
|
||||
ttft = time.perf_counter() - start_time
|
||||
elif event_type == ChatEvent.STATUS:
|
||||
train_of_thought.append({"type": event_type.value, "data": data})
|
||||
|
||||
if event_type == ChatEvent.MESSAGE:
|
||||
yield data
|
||||
elif event_type == ChatEvent.REFERENCES or stream:
|
||||
|
@ -810,6 +815,7 @@ async def chat(
|
|||
conversation_id=conversation_id,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -854,6 +860,7 @@ async def chat(
|
|||
automation_id=automation.id,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
async for result in send_llm_response(llm_response):
|
||||
yield result
|
||||
|
@ -1061,6 +1068,7 @@ async def chat(
|
|||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
content_obj = {
|
||||
"intentType": intent_type,
|
||||
|
@ -1118,6 +1126,7 @@ async def chat(
|
|||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
|
@ -1144,6 +1153,7 @@ async def chat(
|
|||
researched_results,
|
||||
uploaded_images,
|
||||
tracer,
|
||||
train_of_thought,
|
||||
)
|
||||
|
||||
# Send Response
|
||||
|
|
|
@ -1113,6 +1113,7 @@ def generate_chat_response(
|
|||
meta_research: str = "",
|
||||
query_images: Optional[List[str]] = None,
|
||||
tracer: dict = {},
|
||||
train_of_thought: List[Any] = [],
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
|
@ -1137,6 +1138,7 @@ def generate_chat_response(
|
|||
conversation_id=conversation_id,
|
||||
query_images=query_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
|
|
Loading…
Reference in a new issue