mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Store turn id with each chat message. Expose API to delete chat turn
Each chat turn is a user query, khoj response message pair
This commit is contained in:
parent
f64f5b3b6e
commit
ba15686682
4 changed files with 34 additions and 0 deletions
|
@ -1348,6 +1348,17 @@ class ConversationAdapters:
|
|||
conversation.save()
|
||||
return conversation.file_filters
|
||||
|
||||
@staticmethod
|
||||
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
|
||||
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
|
||||
if not conversation:
|
||||
return False
|
||||
conversation_log = conversation.conversation_log
|
||||
updated_log = [msg for msg in conversation_log if msg.get("turnId") != turn_id]
|
||||
conversation.conversation_log = updated_log
|
||||
conversation.save()
|
||||
return True
|
||||
|
||||
|
||||
class FileObjectAdapters:
|
||||
@staticmethod
|
||||
|
|
|
@ -4,6 +4,7 @@ import math
|
|||
import mimetypes
|
||||
import os
|
||||
import queue
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
@ -232,12 +233,14 @@ def save_to_conversation_log(
|
|||
train_of_thought: List[Any] = [],
|
||||
):
|
||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
turn_id = tracer.get("mid") or str(uuid.uuid4())
|
||||
updated_conversation = message_to_log(
|
||||
user_message=q,
|
||||
chat_response=chat_response,
|
||||
user_message_metadata={
|
||||
"created": user_message_time,
|
||||
"images": query_images,
|
||||
"turnId": turn_id,
|
||||
},
|
||||
khoj_message_metadata={
|
||||
"context": compiled_references,
|
||||
|
@ -246,6 +249,7 @@ def save_to_conversation_log(
|
|||
"codeContext": code_results,
|
||||
"automationId": automation_id,
|
||||
"trainOfThought": train_of_thought,
|
||||
"turnId": turn_id,
|
||||
},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
train_of_thought=train_of_thought,
|
||||
|
|
|
@ -38,6 +38,7 @@ from khoj.routers.helpers import (
|
|||
ChatRequestBody,
|
||||
CommonQueryParams,
|
||||
ConversationCommandRateLimiter,
|
||||
DeleteMessageRequestBody,
|
||||
agenerate_chat_response,
|
||||
aget_relevant_information_sources,
|
||||
aget_relevant_output_modes,
|
||||
|
@ -534,6 +535,19 @@ async def set_conversation_title(
|
|||
)
|
||||
|
||||
|
||||
@api_chat.delete("/conversation/message", response_class=Response)
|
||||
@requires(["authenticated"])
|
||||
def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response:
|
||||
user = request.user.object
|
||||
success = ConversationAdapters.delete_message_by_turn_id(
|
||||
user, delete_request.conversation_id, delete_request.turn_id
|
||||
)
|
||||
if success:
|
||||
return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200)
|
||||
else:
|
||||
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
|
||||
|
||||
|
||||
@api_chat.post("")
|
||||
@requires(["authenticated"])
|
||||
async def chat(
|
||||
|
|
|
@ -1264,6 +1264,11 @@ class ChatRequestBody(BaseModel):
|
|||
create_new: Optional[bool] = False
|
||||
|
||||
|
||||
class DeleteMessageRequestBody(BaseModel):
|
||||
conversation_id: str
|
||||
turn_id: str
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
||||
self.requests = requests
|
||||
|
|
Loading…
Reference in a new issue