From 27b07d46b362ddd80f662aafd35dee39b41b667e Mon Sep 17 00:00:00 2001 From: Sean Hatfield <seanhatfield5@gmail.com> Date: Wed, 13 Nov 2024 12:34:42 -0800 Subject: [PATCH] Patch bad models endpoint path in LM Studio embedding engine (#2628) * patch bad models endpoint path in lm studio embedding engine * convert to OpenAI wrapper compatibility * add URL force parser/validation for LMStudio connections * remove comment --------- Co-authored-by: timothycarambat <rambat1010@gmail.com> --- server/utils/AiProviders/lmStudio/index.js | 19 ++++++++- .../utils/EmbeddingEngines/lmstudio/index.js | 42 ++++++++++--------- .../agents/aibitat/providers/ai-provider.js | 3 +- .../agents/aibitat/providers/lmstudio.js | 5 ++- server/utils/helpers/customModels.js | 5 ++- 5 files changed, 51 insertions(+), 23 deletions(-) diff --git a/server/utils/AiProviders/lmStudio/index.js b/server/utils/AiProviders/lmStudio/index.js index f548adcbc..6a30ef742 100644 --- a/server/utils/AiProviders/lmStudio/index.js +++ b/server/utils/AiProviders/lmStudio/index.js @@ -11,7 +11,7 @@ class LMStudioLLM { const { OpenAI: OpenAIApi } = require("openai"); this.lmstudio = new OpenAIApi({ - baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance + baseURL: parseLMStudioBasePath(process.env.LMSTUDIO_BASE_PATH), // here is the URL to your LMStudio instance apiKey: null, }); @@ -173,6 +173,23 @@ class LMStudioLLM { } } +/** + * Parse the base path for the LMStudio API. Since the base path must end in /v1 and cannot have a trailing slash, + * and the user can possibly set it to anything and likely incorrectly due to pasting behaviors, we need to ensure it is in the correct format. + * @param {string} basePath + * @returns {string} + */ +function parseLMStudioBasePath(providedBasePath = "") { + try { + const baseURL = new URL(providedBasePath); + const basePath = `${baseURL.origin}/v1`; + return basePath; + } catch (e) { + return providedBasePath; + } +} + module.exports = { LMStudioLLM, + parseLMStudioBasePath, }; diff --git a/server/utils/EmbeddingEngines/lmstudio/index.js b/server/utils/EmbeddingEngines/lmstudio/index.js index 6874b4b32..b03093aeb 100644 --- a/server/utils/EmbeddingEngines/lmstudio/index.js +++ b/server/utils/EmbeddingEngines/lmstudio/index.js @@ -1,3 +1,4 @@ +const { parseLMStudioBasePath } = require("../../AiProviders/lmStudio"); const { maximumChunkLength } = require("../../helpers"); class LMStudioEmbedder { @@ -6,10 +7,14 @@ class LMStudioEmbedder { throw new Error("No embedding base path was set."); if (!process.env.EMBEDDING_MODEL_PREF) throw new Error("No embedding model was set."); - this.basePath = `${process.env.EMBEDDING_BASE_PATH}/embeddings`; + + const { OpenAI: OpenAIApi } = require("openai"); + this.lmstudio = new OpenAIApi({ + baseURL: parseLMStudioBasePath(process.env.EMBEDDING_BASE_PATH), + apiKey: null, + }); this.model = process.env.EMBEDDING_MODEL_PREF; - // Limit of how many strings we can process in a single pass to stay with resource or network limits // Limit of how many strings we can process in a single pass to stay with resource or network limits this.maxConcurrentChunks = 1; this.embeddingMaxChunkLength = maximumChunkLength(); @@ -20,10 +25,9 @@ class LMStudioEmbedder { } async #isAlive() { - return await fetch(`${this.basePath}/models`, { - method: "HEAD", - }) - .then((res) => res.ok) + return await this.lmstudio.models + .list() + .then((res) => res?.data?.length > 0) .catch((e) => { this.log(e.message); return false; @@ -55,29 +59,29 @@ class LMStudioEmbedder { for (const chunk of textChunks) { if (hasError) break; // If an error occurred don't continue and exit early. results.push( - await fetch(this.basePath, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ + await this.lmstudio.embeddings + .create({ model: this.model, input: chunk, - }), - }) - .then((res) => res.json()) - .then((json) => { - const embedding = json.data[0].embedding; + }) + .then((result) => { + const embedding = result.data?.[0]?.embedding; if (!Array.isArray(embedding) || !embedding.length) throw { type: "EMPTY_ARR", message: "The embedding was empty from LMStudio", }; + console.log(`Embedding length: ${embedding.length}`); return { data: embedding, error: null }; }) - .catch((error) => { + .catch((e) => { + e.type = + e?.response?.data?.error?.code || + e?.response?.status || + "failed_to_embed"; + e.message = e?.response?.data?.error?.message || e.message; hasError = true; - return { data: [], error }; + return { data: [], error: e }; }) ); } diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js index 1bbf4a0a4..4ba6840d7 100644 --- a/server/utils/agents/aibitat/providers/ai-provider.js +++ b/server/utils/agents/aibitat/providers/ai-provider.js @@ -16,6 +16,7 @@ const { ChatBedrockConverse } = require("@langchain/aws"); const { ChatOllama } = require("@langchain/community/chat_models/ollama"); const { toValidNumber } = require("../../../http"); const { getLLMProviderClass } = require("../../../helpers"); +const { parseLMStudioBasePath } = require("../../../AiProviders/lmStudio"); const DEFAULT_WORKSPACE_PROMPT = "You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions."; @@ -169,7 +170,7 @@ class Provider { case "lmstudio": return new ChatOpenAI({ configuration: { - baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), + baseURL: parseLMStudioBasePath(process.env.LMSTUDIO_BASE_PATH), }, apiKey: "not-used", // Needs to be specified or else will assume OpenAI ...config, diff --git a/server/utils/agents/aibitat/providers/lmstudio.js b/server/utils/agents/aibitat/providers/lmstudio.js index c8f7c9108..ec9be89cb 100644 --- a/server/utils/agents/aibitat/providers/lmstudio.js +++ b/server/utils/agents/aibitat/providers/lmstudio.js @@ -2,6 +2,9 @@ const OpenAI = require("openai"); const Provider = require("./ai-provider.js"); const InheritMultiple = require("./helpers/classes.js"); const UnTooled = require("./helpers/untooled.js"); +const { + parseLMStudioBasePath, +} = require("../../../AiProviders/lmStudio/index.js"); /** * The agent provider for the LMStudio. @@ -18,7 +21,7 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) { const model = config?.model || process.env.LMSTUDIO_MODEL_PREF || "Loaded from Chat UI"; const client = new OpenAI({ - baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance + baseURL: parseLMStudioBasePath(process.env.LMSTUDIO_BASE_PATH), apiKey: null, maxRetries: 3, }); diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js index 163933769..72882d6d1 100644 --- a/server/utils/helpers/customModels.js +++ b/server/utils/helpers/customModels.js @@ -5,6 +5,7 @@ const { togetherAiModels } = require("../AiProviders/togetherAi"); const { fireworksAiModels } = require("../AiProviders/fireworksAi"); const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs"); const { fetchNovitaModels } = require("../AiProviders/novita"); +const { parseLMStudioBasePath } = require("../AiProviders/lmStudio"); const SUPPORT_CUSTOM_MODELS = [ "openai", "localai", @@ -235,7 +236,9 @@ async function getLMStudioModels(basePath = null) { try { const { OpenAI: OpenAIApi } = require("openai"); const openai = new OpenAIApi({ - baseURL: basePath || process.env.LMSTUDIO_BASE_PATH, + baseURL: parseLMStudioBasePath( + basePath || process.env.LMSTUDIO_BASE_PATH + ), apiKey: null, }); const models = await openai.models