From 2a1202de548e8857764f0f129169a0dfd911c5c7 Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Thu, 28 Dec 2023 13:59:47 -0800
Subject: [PATCH] Patch Ollama Streaming chunk issues (#500)

Replace stream/sync chats with Langchain interface for now
connect #499
ref: https://github.com/Mintplex-Labs/anything-llm/issues/495#issuecomment-1871476091
---
 .vscode/settings.json                    |   1 +
 server/utils/AiProviders/ollama/index.js | 175 ++++++++++-------------
 server/utils/chats/stream.js             |  35 +----
 3 files changed, 79 insertions(+), 132 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 2e43b1926..e6e720a96 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,6 +1,7 @@
 {
   "cSpell.words": [
     "Dockerized",
+    "Langchain",
     "Ollama",
     "openai",
     "Qdrant",
diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js
index 3aa58f760..f160e5d36 100644
--- a/server/utils/AiProviders/ollama/index.js
+++ b/server/utils/AiProviders/ollama/index.js
@@ -1,4 +1,5 @@
 const { chatPrompt } = require("../../chats");
+const { StringOutputParser } = require("langchain/schema/output_parser");
 
 // Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md
 class OllamaAILLM {
@@ -21,6 +22,42 @@ class OllamaAILLM {
     this.embedder = embedder;
   }
 
+  #ollamaClient({ temperature = 0.07 }) {
+    const { ChatOllama } = require("langchain/chat_models/ollama");
+    return new ChatOllama({
+      baseUrl: this.basePath,
+      model: this.model,
+      temperature,
+    });
+  }
+
+  // For streaming we use Langchain's wrapper to handle weird chunks
+  // or otherwise absorb headaches that can arise from Ollama models
+  #convertToLangchainPrototypes(chats = []) {
+    const {
+      HumanMessage,
+      SystemMessage,
+      AIMessage,
+    } = require("langchain/schema");
+    const langchainChats = [];
+    for (const chat of chats) {
+      switch (chat.role) {
+        case "system":
+          langchainChats.push(new SystemMessage({ content: chat.content }));
+          break;
+        case "user":
+          langchainChats.push(new HumanMessage({ content: chat.content }));
+          break;
+        case "assistant":
+          langchainChats.push(new AIMessage({ content: chat.content }));
+          break;
+        default:
+          break;
+      }
+    }
+    return langchainChats;
+  }
+
   streamingEnabled() {
     return "streamChat" in this && "streamGetChatCompletion" in this;
   }
@@ -63,37 +100,21 @@ Context:
   }
 
   async sendChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
-    const textResponse = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
+    const messages = await this.compressMessages(
+      {
+        systemPrompt: chatPrompt(workspace),
+        userPrompt: prompt,
+        chatHistory,
       },
-      body: JSON.stringify({
-        model: this.model,
-        stream: false,
-        options: {
-          temperature: Number(workspace?.openAiTemp ?? 0.7),
-        },
-        messages: await this.compressMessages(
-          {
-            systemPrompt: chatPrompt(workspace),
-            userPrompt: prompt,
-            chatHistory,
-          },
-          rawHistory
-        ),
-      }),
-    })
-      .then((res) => {
-        if (!res.ok)
-          throw new Error(`Ollama:sendChat ${res.status} ${res.statusText}`);
-        return res.json();
-      })
-      .then((data) => data?.message?.content)
-      .catch((e) => {
-        console.error(e);
-        throw new Error(`Ollama::sendChat failed with: ${error.message}`);
-      });
+      rawHistory
+    );
+
+    const model = this.#ollamaClient({
+      temperature: Number(workspace?.openAiTemp ?? 0.7),
+    });
+    const textResponse = await model
+      .pipe(new StringOutputParser())
+      .invoke(this.#convertToLangchainPrototypes(messages));
 
     if (!textResponse.length)
       throw new Error(`Ollama::sendChat text response was empty.`);
@@ -102,63 +123,29 @@ Context:
   }
 
   async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
-    const response = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
+    const messages = await this.compressMessages(
+      {
+        systemPrompt: chatPrompt(workspace),
+        userPrompt: prompt,
+        chatHistory,
       },
-      body: JSON.stringify({
-        model: this.model,
-        stream: true,
-        options: {
-          temperature: Number(workspace?.openAiTemp ?? 0.7),
-        },
-        messages: await this.compressMessages(
-          {
-            systemPrompt: chatPrompt(workspace),
-            userPrompt: prompt,
-            chatHistory,
-          },
-          rawHistory
-        ),
-      }),
-    }).catch((e) => {
-      console.error(e);
-      throw new Error(`Ollama:streamChat ${error.message}`);
-    });
+      rawHistory
+    );
 
-    return { type: "ollamaStream", response };
+    const model = this.#ollamaClient({
+      temperature: Number(workspace?.openAiTemp ?? 0.7),
+    });
+    const stream = await model
+      .pipe(new StringOutputParser())
+      .stream(this.#convertToLangchainPrototypes(messages));
+    return stream;
   }
 
   async getChatCompletion(messages = null, { temperature = 0.7 }) {
-    const textResponse = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify({
-        model: this.model,
-        messages,
-        stream: false,
-        options: {
-          temperature,
-        },
-      }),
-    })
-      .then((res) => {
-        if (!res.ok)
-          throw new Error(
-            `Ollama:getChatCompletion ${res.status} ${res.statusText}`
-          );
-        return res.json();
-      })
-      .then((data) => data?.message?.content)
-      .catch((e) => {
-        console.error(e);
-        throw new Error(
-          `Ollama::getChatCompletion failed with: ${error.message}`
-        );
-      });
+    const model = this.#ollamaClient({ temperature });
+    const textResponse = await model
+      .pipe(new StringOutputParser())
+      .invoke(this.#convertToLangchainPrototypes(messages));
 
     if (!textResponse.length)
       throw new Error(`Ollama::getChatCompletion text response was empty.`);
@@ -167,25 +154,11 @@ Context:
   }
 
   async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
-    const response = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify({
-        model: this.model,
-        stream: true,
-        messages,
-        options: {
-          temperature,
-        },
-      }),
-    }).catch((e) => {
-      console.error(e);
-      throw new Error(`Ollama:streamGetChatCompletion ${error.message}`);
-    });
-
-    return { type: "ollamaStream", response };
+    const model = this.#ollamaClient({ temperature });
+    const stream = await model
+      .pipe(new StringOutputParser())
+      .stream(this.#convertToLangchainPrototypes(messages));
+    return stream;
   }
 
   // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js
index b0dc9186b..240e4a173 100644
--- a/server/utils/chats/stream.js
+++ b/server/utils/chats/stream.js
@@ -232,46 +232,19 @@ function handleStreamResponses(response, stream, responseProps) {
     });
   }
 
-  if (stream?.type === "ollamaStream") {
-    return new Promise(async (resolve) => {
-      let fullText = "";
-      for await (const dataChunk of stream.response.body) {
-        const chunk = JSON.parse(Buffer.from(dataChunk).toString());
-        fullText += chunk.message.content;
-        writeResponseChunk(response, {
-          uuid,
-          sources: [],
-          type: "textResponseChunk",
-          textResponse: chunk.message.content,
-          close: false,
-          error: false,
-        });
-      }
-
-      writeResponseChunk(response, {
-        uuid,
-        sources,
-        type: "textResponseChunk",
-        textResponse: "",
-        close: true,
-        error: false,
-      });
-      resolve(fullText);
-    });
-  }
-
-  // If stream is not a regular OpenAI Stream (like if using native model)
+  // If stream is not a regular OpenAI Stream (like if using native model, Ollama, or most LangChain interfaces)
   // we can just iterate the stream content instead.
   if (!stream.hasOwnProperty("data")) {
     return new Promise(async (resolve) => {
       let fullText = "";
       for await (const chunk of stream) {
-        fullText += chunk.content;
+        const content = chunk.hasOwnProperty("content") ? chunk.content : chunk;
+        fullText += content;
         writeResponseChunk(response, {
           uuid,
           sources: [],
           type: "textResponseChunk",
-          textResponse: chunk.content,
+          textResponse: content,
           close: false,
           error: false,
         });