Store turn id with each chat message. Expose API to delete chat turn

Each chat turn is a user query, khoj response message pair
This commit is contained in:
Debanjum 2024-10-30 04:21:58 -07:00
parent f64f5b3b6e
commit ba15686682
4 changed files with 34 additions and 0 deletions

View file

@ -1348,6 +1348,17 @@ class ConversationAdapters:
conversation.save() conversation.save()
return conversation.file_filters return conversation.file_filters
@staticmethod
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
if not conversation:
return False
conversation_log = conversation.conversation_log
updated_log = [msg for msg in conversation_log if msg.get("turnId") != turn_id]
conversation.conversation_log = updated_log
conversation.save()
return True
class FileObjectAdapters: class FileObjectAdapters:
@staticmethod @staticmethod

View file

@ -4,6 +4,7 @@ import math
import mimetypes import mimetypes
import os import os
import queue import queue
import uuid
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@ -232,12 +233,14 @@ def save_to_conversation_log(
train_of_thought: List[Any] = [], train_of_thought: List[Any] = [],
): ):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
turn_id = tracer.get("mid") or str(uuid.uuid4())
updated_conversation = message_to_log( updated_conversation = message_to_log(
user_message=q, user_message=q,
chat_response=chat_response, chat_response=chat_response,
user_message_metadata={ user_message_metadata={
"created": user_message_time, "created": user_message_time,
"images": query_images, "images": query_images,
"turnId": turn_id,
}, },
khoj_message_metadata={ khoj_message_metadata={
"context": compiled_references, "context": compiled_references,
@ -246,6 +249,7 @@ def save_to_conversation_log(
"codeContext": code_results, "codeContext": code_results,
"automationId": automation_id, "automationId": automation_id,
"trainOfThought": train_of_thought, "trainOfThought": train_of_thought,
"turnId": turn_id,
}, },
conversation_log=meta_log.get("chat", []), conversation_log=meta_log.get("chat", []),
train_of_thought=train_of_thought, train_of_thought=train_of_thought,

View file

@ -38,6 +38,7 @@ from khoj.routers.helpers import (
ChatRequestBody, ChatRequestBody,
CommonQueryParams, CommonQueryParams,
ConversationCommandRateLimiter, ConversationCommandRateLimiter,
DeleteMessageRequestBody,
agenerate_chat_response, agenerate_chat_response,
aget_relevant_information_sources, aget_relevant_information_sources,
aget_relevant_output_modes, aget_relevant_output_modes,
@ -534,6 +535,19 @@ async def set_conversation_title(
) )
@api_chat.delete("/conversation/message", response_class=Response)
@requires(["authenticated"])
def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response:
user = request.user.object
success = ConversationAdapters.delete_message_by_turn_id(
user, delete_request.conversation_id, delete_request.turn_id
)
if success:
return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200)
else:
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
@api_chat.post("") @api_chat.post("")
@requires(["authenticated"]) @requires(["authenticated"])
async def chat( async def chat(

View file

@ -1264,6 +1264,11 @@ class ChatRequestBody(BaseModel):
create_new: Optional[bool] = False create_new: Optional[bool] = False
class DeleteMessageRequestBody(BaseModel):
conversation_id: str
turn_id: str
class ApiUserRateLimiter: class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str): def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests self.requests = requests