Store turn id with each chat message. Expose API to delete chat turn

Each chat turn is a user query, khoj response message pair
This commit is contained in:
Debanjum 2024-10-30 04:21:58 -07:00
parent f64f5b3b6e
commit ba15686682
4 changed files with 34 additions and 0 deletions

View file

@ -1348,6 +1348,17 @@ class ConversationAdapters:
conversation.save()
return conversation.file_filters
@staticmethod
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
if not conversation:
return False
conversation_log = conversation.conversation_log
updated_log = [msg for msg in conversation_log if msg.get("turnId") != turn_id]
conversation.conversation_log = updated_log
conversation.save()
return True
class FileObjectAdapters:
@staticmethod

View file

@ -4,6 +4,7 @@ import math
import mimetypes
import os
import queue
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
@ -232,12 +233,14 @@ def save_to_conversation_log(
train_of_thought: List[Any] = [],
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
turn_id = tracer.get("mid") or str(uuid.uuid4())
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={
"created": user_message_time,
"images": query_images,
"turnId": turn_id,
},
khoj_message_metadata={
"context": compiled_references,
@ -246,6 +249,7 @@ def save_to_conversation_log(
"codeContext": code_results,
"automationId": automation_id,
"trainOfThought": train_of_thought,
"turnId": turn_id,
},
conversation_log=meta_log.get("chat", []),
train_of_thought=train_of_thought,

View file

@ -38,6 +38,7 @@ from khoj.routers.helpers import (
ChatRequestBody,
CommonQueryParams,
ConversationCommandRateLimiter,
DeleteMessageRequestBody,
agenerate_chat_response,
aget_relevant_information_sources,
aget_relevant_output_modes,
@ -534,6 +535,19 @@ async def set_conversation_title(
)
@api_chat.delete("/conversation/message", response_class=Response)
@requires(["authenticated"])
def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response:
user = request.user.object
success = ConversationAdapters.delete_message_by_turn_id(
user, delete_request.conversation_id, delete_request.turn_id
)
if success:
return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200)
else:
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
@api_chat.post("")
@requires(["authenticated"])
async def chat(

View file

@ -1264,6 +1264,11 @@ class ChatRequestBody(BaseModel):
create_new: Optional[bool] = False
class DeleteMessageRequestBody(BaseModel):
conversation_id: str
turn_id: str
class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests