Persist the train of thought in the conversation history

This commit is contained in:
sabaimran 2024-10-26 23:46:15 -07:00
parent 9e8ac7f89e
commit a121d67b10
5 changed files with 117 additions and 50 deletions

View file

@ -13,13 +13,14 @@ import { ScrollArea } from "@/components/ui/scroll-area";
import { InlineLoading } from "../loading/loading";
import { Lightbulb, ArrowDown } from "@phosphor-icons/react";
import { Lightbulb, ArrowDown, XCircle } from "@phosphor-icons/react";
import AgentProfileCard from "../profileCard/profileCard";
import { getIconFromIconName } from "@/app/common/iconUtils";
import { AgentData } from "@/app/agents/page";
import React from "react";
import { useIsMobileWidth } from "@/app/common/utils";
import { Button } from "@/components/ui/button";
interface ChatResponse {
status: string;
@ -40,26 +41,51 @@ interface ChatHistoryProps {
customClassName?: string;
}
function constructTrainOfThought(
trainOfThought: string[],
lastMessage: boolean,
agentColor: string,
key: string,
completed: boolean = false,
) {
const lastIndex = trainOfThought.length - 1;
return (
<div className={`${styles.trainOfThought} shadow-sm`} key={key}>
{!completed && <InlineLoading className="float-right" />}
// Props for the collapsible "train of thought" panel shown alongside a chat message.
interface TrainOfThoughtComponentProps {
    trainOfThought: string[]; // Ordered thought/status strings to render, oldest first
    lastMessage: boolean; // True when this panel belongs to the newest (possibly streaming) message
    agentColor: string; // Agent theme color name, forwarded to each TrainOfThought row
    // NOTE(review): `key` is a reserved React prop — it is consumed at the JSX
    // call site and never appears on `props` at runtime (props.key is undefined).
    // Consider removing it from this interface or renaming it.
    key: string;
    completed?: boolean; // When true, the response finished; the panel starts collapsed
}
{trainOfThought.map((train, index) => (
<TrainOfThought
key={`train-${index}`}
message={train}
primary={index === lastIndex && lastMessage && !completed}
agentColor={agentColor}
/>
))}
/**
 * Collapsible panel displaying the agent's train of thought for one message.
 *
 * While the response is still streaming (`completed` falsy) the panel is
 * expanded with a loading indicator; once completed it starts collapsed and
 * the user can toggle it open/closed.
 */
function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
    const lastIndex = props.trainOfThought.length - 1;
    // Completed trains start collapsed; in-flight ones stay expanded so new
    // thoughts are visible as they stream in.
    const [collapsed, setCollapsed] = useState(props.completed);

    return (
        // Fix: removed `key={props.key}`. `key` is a reserved React prop that is
        // consumed at the JSX call site and never delivered on `props`, so the
        // previous expression resolved to `undefined` and triggered a dev-mode
        // warning. The caller-supplied key already handles reconciliation.
        <div className={`${!collapsed ? styles.trainOfThought : ""} shadow-sm`}>
            {!props.completed && <InlineLoading className="float-right" />}
            {collapsed ? (
                <Button
                    className="w-fit text-left justify-start content-start text-xs"
                    onClick={() => setCollapsed(false)}
                    variant="ghost"
                    size="sm"
                >
                    What was my train of thought?
                </Button>
            ) : (
                <Button
                    className="w-fit text-left justify-start content-start text-xs"
                    onClick={() => setCollapsed(true)}
                    variant="ghost"
                    size="sm"
                >
                    <XCircle size={16} className="mr-1" />
                    Close
                </Button>
            )}
            {!collapsed &&
                props.trainOfThought.map((train, index) => (
                    <TrainOfThought
                        key={`train-${index}`}
                        message={train}
                        // Highlight only the newest thought of a still-streaming last message
                        primary={index === lastIndex && props.lastMessage && !props.completed}
                        agentColor={props.agentColor}
                    />
                ))}
        </div>
    );
}
@ -265,25 +291,39 @@ export default function ChatHistory(props: ChatHistoryProps) {
{data &&
data.chat &&
data.chat.map((chatMessage, index) => (
<ChatMessage
key={`${index}fullHistory`}
ref={
// attach ref to the second last message to handle scroll on page load
index === data.chat.length - 2
? latestUserMessageRef
: // attach ref to the newest fetched message to handle scroll on fetch
// note: stabilize index selection against last page having less messages than fetchMessageCount
index ===
data.chat.length - (currentPage - 1) * fetchMessageCount
? latestFetchedMessageRef
: null
}
isMobileWidth={isMobileWidth}
chatMessage={chatMessage}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={index === data.chat.length - 1}
/>
<>
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
<TrainOfThoughtComponent
trainOfThought={chatMessage.trainOfThought?.map(
(train) => train.data,
)}
lastMessage={false}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
completed={true}
/>
)}
<ChatMessage
key={`${index}fullHistory`}
ref={
// attach ref to the second last message to handle scroll on page load
index === data.chat.length - 2
? latestUserMessageRef
: // attach ref to the newest fetched message to handle scroll on fetch
// note: stabilize index selection against last page having less messages than fetchMessageCount
index ===
data.chat.length -
(currentPage - 1) * fetchMessageCount
? latestFetchedMessageRef
: null
}
isMobileWidth={isMobileWidth}
chatMessage={chatMessage}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={index === data.chat.length - 1}
/>
</>
))}
{props.incomingMessages &&
props.incomingMessages.map((message, index) => {
@ -305,14 +345,15 @@ export default function ChatHistory(props: ChatHistoryProps) {
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
/>
{message.trainOfThought &&
constructTrainOfThought(
message.trainOfThought,
index === incompleteIncomingMessageIndex,
data?.agent?.color || "orange",
`${index}trainOfThought`,
message.completed,
)}
{message.trainOfThought && (
<TrainOfThoughtComponent
trainOfThought={message.trainOfThought}
lastMessage={index === incompleteIncomingMessageIndex}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
completed={message.completed}
/>
)}
<ChatMessage
key={`${index}incoming`}
isMobileWidth={isMobileWidth}

View file

@ -128,6 +128,11 @@ interface Intent {
"inferred-queries": string[];
}
// One step of the agent's train of thought, as persisted in the conversation
// log (the backend appends entries of shape { type: <event type>, data: <text> }
// for status events — see the chat endpoint's send_event handler).
interface TrainOfThoughtObject {
    type: string; // Originating chat event type (e.g. a status event) — TODO confirm full set against backend
    data: string; // The thought/status text to display
}
export interface SingleChatMessage {
automationId: string;
by: string;
@ -136,6 +141,7 @@ export interface SingleChatMessage {
context: Context[];
onlineContext: OnlineContext;
codeContext: CodeContext;
trainOfThought?: TrainOfThoughtObject[];
rawQuery?: string;
intent?: Intent;
agent?: AgentData;

View file

@ -146,7 +146,12 @@ class ChatEvent(Enum):
def message_to_log(
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
user_message,
chat_response,
user_message_metadata={},
khoj_message_metadata={},
conversation_log=[],
train_of_thought=[],
):
"""Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = {
@ -182,6 +187,7 @@ def save_to_conversation_log(
automation_id: str = None,
query_images: List[str] = None,
tracer: Dict[str, Any] = {},
train_of_thought: List[Any] = [],
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
@ -197,8 +203,10 @@ def save_to_conversation_log(
"onlineContext": online_results,
"codeContext": code_results,
"automationId": automation_id,
"trainOfThought": train_of_thought,
},
conversation_log=meta_log.get("chat", []),
train_of_thought=train_of_thought,
)
ConversationAdapters.save_conversation(
user,

View file

@ -570,7 +570,9 @@ async def chat(
user: KhojUser = request.user.object
event_delimiter = "␃🔚␗"
q = unquote(q)
train_of_thought = []
nonlocal conversation_id
tracer: dict = {
"mid": f"{uuid.uuid4()}",
"cid": conversation_id,
@ -590,7 +592,7 @@ async def chat(
uploaded_images.append(uploaded_image)
async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft
nonlocal connection_alive, ttft, train_of_thought
if not connection_alive or await request.is_disconnected():
connection_alive = False
logger.warning(f"User {user} disconnected from {common.client} client")
@ -598,8 +600,11 @@ async def chat(
try:
if event_type == ChatEvent.END_LLM_RESPONSE:
collect_telemetry()
if event_type == ChatEvent.START_LLM_RESPONSE:
elif event_type == ChatEvent.START_LLM_RESPONSE:
ttft = time.perf_counter() - start_time
elif event_type == ChatEvent.STATUS:
train_of_thought.append({"type": event_type.value, "data": data})
if event_type == ChatEvent.MESSAGE:
yield data
elif event_type == ChatEvent.REFERENCES or stream:
@ -810,6 +815,7 @@ async def chat(
conversation_id=conversation_id,
query_images=uploaded_images,
tracer=tracer,
train_of_thought=train_of_thought,
)
return
@ -854,6 +860,7 @@ async def chat(
automation_id=automation.id,
query_images=uploaded_images,
tracer=tracer,
train_of_thought=train_of_thought,
)
async for result in send_llm_response(llm_response):
yield result
@ -1061,6 +1068,7 @@ async def chat(
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
train_of_thought=train_of_thought,
)
content_obj = {
"intentType": intent_type,
@ -1118,6 +1126,7 @@ async def chat(
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
train_of_thought=train_of_thought,
)
async for result in send_llm_response(json.dumps(content_obj)):
@ -1144,6 +1153,7 @@ async def chat(
researched_results,
uploaded_images,
tracer,
train_of_thought,
)
# Send Response

View file

@ -1113,6 +1113,7 @@ def generate_chat_response(
meta_research: str = "",
query_images: Optional[List[str]] = None,
tracer: dict = {},
train_of_thought: List[Any] = [],
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
@ -1137,6 +1138,7 @@ def generate_chat_response(
conversation_id=conversation_id,
query_images=query_images,
tracer=tracer,
train_of_thought=train_of_thought,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)