Agent Context window + context window refactor. ()

* Enable agent context windows to be accurate per provider:model

* Refactor model mapping to external file
Use token count instead of char count for document length
Reference promptWindowLimit from AIProvider in a central location

* remove unused imports
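
In short, agent tooling now asks Provider.contextLimit(provider, model), and every AIProvider class answers through a new static promptWindowLimit(modelName) backed either by the shared MODEL_MAP file or by the provider's token-limit environment variable. A minimal before/after sketch (names are the ones introduced in the diffs below; the require path for the Provider class is assumed; illustrative only):

// Illustrative only - exact require path for the Provider class is assumed.
const Provider = require("./server/utils/agents/aibitat/providers/ai-provider");

// Before: per-provider constants, e.g. contextLimit("openai") was always 8_000.
// After: resolution is per provider:model.
Provider.contextLimit("openai", "gpt-4o");        // 128_000 (MODEL_MAP.openai)
Provider.contextLimit("anthropic", "claude-2.1"); // 200_000 (MODEL_MAP.anthropic)
Provider.contextLimit("ollama");                  // OLLAMA_MODEL_TOKEN_LIMIT or 4096
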
Timothy Carambat 2024-08-15 12:13:28 -07:00 committed by GitHub
parent 4365d69359
commit 99f2c25b1c
25 changed files with 284 additions and 94 deletions
server/utils
  AiProviders
    anthropic
    azureOpenAi
    bedrock
    cohere
    gemini
    genericOpenAi
    groq
    huggingface
    koboldCPP
    liteLLM
    lmStudio
    localAi
    mistral
    modelMap.js
    native
    ollama
    openAi
    openRouter
    perplexity
    textGenWebUI
    togetherAi
  agents/aibitat
  helpers

View file

@@ -4,6 +4,7 @@ const {
clientAbortedHandler,
} = require("../../helpers/chat/responses");
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const { MODEL_MAP } = require("../modelMap");
class AnthropicLLM {
constructor(embedder = null, modelPreference = null) {
@@ -32,25 +33,12 @@ class AnthropicLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(modelName) {
return MODEL_MAP.anthropic[modelName] ?? 100_000;
}
promptWindowLimit() {
switch (this.model) {
case "claude-instant-1.2":
return 100_000;
case "claude-2.0":
return 100_000;
case "claude-2.1":
return 200_000;
case "claude-3-opus-20240229":
return 200_000;
case "claude-3-sonnet-20240229":
return 200_000;
case "claude-3-haiku-20240307":
return 200_000;
case "claude-3-5-sonnet-20240620":
return 200_000;
default:
return 100_000; // assume a claude-instant-1.2 model
}
return MODEL_MAP.anthropic[this.model] ?? 100_000;
}
isValidChatCompletionModel(modelName = "") {
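
With the switch removed, the static method and the instance method share one MODEL_MAP lookup; the Cohere, Gemini, Groq, and OpenAI diffs below follow the same pattern. A small usage sketch (require path assumed; values from MODEL_MAP.anthropic in this commit):

// Usage sketch - the static lookup needs no instance or API key.
const { AnthropicLLM } = require("./server/utils/AiProviders/anthropic"); // path assumed
AnthropicLLM.promptWindowLimit("claude-3-5-sonnet-20240620"); // 200_000
AnthropicLLM.promptWindowLimit("claude-instant-1.2");         // 100_000
AnthropicLLM.promptWindowLimit("unlisted-model");             // 100_000 fallback (?? 100_000)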

View file

@@ -43,6 +43,12 @@ class AzureOpenAiLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
return !!process.env.AZURE_OPENAI_TOKEN_LIMIT
? Number(process.env.AZURE_OPENAI_TOKEN_LIMIT)
: 4096;
}
// Ensure the user selected a proper value for the token limit
// could be any of these https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-models
// and if undefined - assume it is the lowest end.
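
For providers without a fixed model catalog (Azure, Bedrock, HuggingFace, KoboldCPP, LiteLLM, LMStudio, LocalAI, Native, Ollama, TextGenWebUI, and the generic OpenAI connector below), the new static method reads the same token-limit environment variable the instance method already used, so the window can be queried without constructing the class. A hedged sketch using the Azure variable from this hunk (require path assumed):

// Sketch - env var name as shown in the hunk above; require path assumed.
process.env.AZURE_OPENAI_TOKEN_LIMIT = "128000";
const { AzureOpenAiLLM } = require("./server/utils/AiProviders/azureOpenAi");
AzureOpenAiLLM.promptWindowLimit(); // 128000; 4096 when the variable is unset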

View file

@@ -82,6 +82,13 @@ class AWSBedrockLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
const limit = process.env.AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT || 8191;
if (!limit || isNaN(Number(limit)))
throw new Error("No valid token context limit was set.");
return Number(limit);
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {

View file

@@ -1,6 +1,7 @@
const { v4 } = require("uuid");
const { writeResponseChunk } = require("../../helpers/chat/responses");
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const { MODEL_MAP } = require("../modelMap");
class CohereLLM {
constructor(embedder = null) {
@@ -58,23 +59,12 @@ class CohereLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(modelName) {
return MODEL_MAP.cohere[modelName] ?? 4_096;
}
promptWindowLimit() {
switch (this.model) {
case "command-r":
return 128_000;
case "command-r-plus":
return 128_000;
case "command":
return 4_096;
case "command-light":
return 4_096;
case "command-nightly":
return 8_192;
case "command-light-nightly":
return 8_192;
default:
return 4_096;
}
return MODEL_MAP.cohere[this.model] ?? 4_096;
}
async isValidChatCompletionModel(model = "") {

View file

@@ -3,6 +3,7 @@ const {
writeResponseChunk,
clientAbortedHandler,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
class GeminiLLM {
constructor(embedder = null, modelPreference = null) {
@@ -89,21 +90,12 @@ class GeminiLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(modelName) {
return MODEL_MAP.gemini[modelName] ?? 30_720;
}
promptWindowLimit() {
switch (this.model) {
case "gemini-pro":
return 30_720;
case "gemini-1.0-pro":
return 30_720;
case "gemini-1.5-flash-latest":
return 1_048_576;
case "gemini-1.5-pro-latest":
return 2_097_152;
case "gemini-1.5-pro-exp-0801":
return 2_097_152;
default:
return 30_720; // assume a gemini-pro model
}
return MODEL_MAP.gemini[this.model] ?? 30_720;
}
isValidChatCompletionModel(modelName = "") {

View file

@@ -55,6 +55,13 @@ class GenericOpenAiLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
const limit = process.env.GENERIC_OPEN_AI_MODEL_TOKEN_LIMIT || 4096;
if (!limit || isNaN(Number(limit)))
throw new Error("No token context limit was set.");
return Number(limit);
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {

View file

@@ -2,6 +2,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
handleDefaultStreamResponseV2,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
class GroqLLM {
constructor(embedder = null, modelPreference = null) {
@@ -40,21 +41,12 @@ class GroqLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(modelName) {
return MODEL_MAP.groq[modelName] ?? 8192;
}
promptWindowLimit() {
switch (this.model) {
case "gemma2-9b-it":
case "gemma-7b-it":
case "llama3-70b-8192":
case "llama3-8b-8192":
return 8192;
case "llama-3.1-70b-versatile":
case "llama-3.1-8b-instant":
return 8000;
case "mixtral-8x7b-32768":
return 32768;
default:
return 8192;
}
return MODEL_MAP.groq[this.model] ?? 8192;
}
async isValidChatCompletionModel(modelName = "") {

View file

@@ -45,6 +45,13 @@ class HuggingFaceLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
const limit = process.env.HUGGING_FACE_LLM_TOKEN_LIMIT || 4096;
if (!limit || isNaN(Number(limit)))
throw new Error("No HuggingFace token context limit was set.");
return Number(limit);
}
promptWindowLimit() {
const limit = process.env.HUGGING_FACE_LLM_TOKEN_LIMIT || 4096;
if (!limit || isNaN(Number(limit)))

View file

@@ -51,6 +51,13 @@ class KoboldCPPLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
const limit = process.env.KOBOLD_CPP_MODEL_TOKEN_LIMIT || 4096;
if (!limit || isNaN(Number(limit)))
throw new Error("No token context limit was set.");
return Number(limit);
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {

View file

@@ -50,6 +50,13 @@ class LiteLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
const limit = process.env.LITE_LLM_MODEL_TOKEN_LIMIT || 4096;
if (!limit || isNaN(Number(limit)))
throw new Error("No token context limit was set.");
return Number(limit);
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {

View file

@@ -48,6 +48,13 @@ class LMStudioLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
const limit = process.env.LMSTUDIO_MODEL_TOKEN_LIMIT || 4096;
if (!limit || isNaN(Number(limit)))
throw new Error("No LMStudio token context limit was set.");
return Number(limit);
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {

View file

@@ -40,6 +40,13 @@ class LocalAiLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
const limit = process.env.LOCAL_AI_MODEL_TOKEN_LIMIT || 4096;
if (!limit || isNaN(Number(limit)))
throw new Error("No LocalAi token context limit was set.");
return Number(limit);
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {

View file

@@ -41,6 +41,10 @@ class MistralLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit() {
return 32000;
}
promptWindowLimit() {
return 32000;
}

View file

@@ -0,0 +1,55 @@
/**
* The model name and context window for all known model windows
* that are available through providers which have discrete model options.
*/
const MODEL_MAP = {
anthropic: {
"claude-instant-1.2": 100_000,
"claude-2.0": 100_000,
"claude-2.1": 200_000,
"claude-3-opus-20240229": 200_000,
"claude-3-sonnet-20240229": 200_000,
"claude-3-haiku-20240307": 200_000,
"claude-3-5-sonnet-20240620": 200_000,
},
cohere: {
"command-r": 128_000,
"command-r-plus": 128_000,
command: 4_096,
"command-light": 4_096,
"command-nightly": 8_192,
"command-light-nightly": 8_192,
},
gemini: {
"gemini-pro": 30_720,
"gemini-1.0-pro": 30_720,
"gemini-1.5-flash-latest": 1_048_576,
"gemini-1.5-pro-latest": 2_097_152,
"gemini-1.5-pro-exp-0801": 2_097_152,
},
groq: {
"gemma2-9b-it": 8192,
"gemma-7b-it": 8192,
"llama3-70b-8192": 8192,
"llama3-8b-8192": 8192,
"llama-3.1-70b-versatile": 8000,
"llama-3.1-8b-instant": 8000,
"mixtral-8x7b-32768": 32768,
},
openai: {
"gpt-3.5-turbo": 16_385,
"gpt-3.5-turbo-1106": 16_385,
"gpt-4o": 128_000,
"gpt-4o-2024-08-06": 128_000,
"gpt-4o-2024-05-13": 128_000,
"gpt-4o-mini": 128_000,
"gpt-4o-mini-2024-07-18": 128_000,
"gpt-4-turbo": 128_000,
"gpt-4-1106-preview": 128_000,
"gpt-4-turbo-preview": 128_000,
"gpt-4": 8_192,
"gpt-4-32k": 32_000,
},
};
module.exports = { MODEL_MAP };
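
The map is a plain nested object keyed by provider and then model, so each provider's promptWindowLimit becomes a single property access with a provider-specific nullish-coalescing fallback. A minimal lookup sketch (require path assumed):

// Minimal lookup sketch against the MODEL_MAP exported above.
const { MODEL_MAP } = require("./server/utils/AiProviders/modelMap"); // path assumed
MODEL_MAP.gemini["gemini-1.5-pro-latest"];    // 2_097_152
MODEL_MAP.openai["gpt-4o-mini"];              // 128_000
MODEL_MAP.openai["some-fine-tune"] ?? 4_096;  // fallback used by OpenAiLLM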

View file

@@ -96,6 +96,13 @@ class NativeLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
const limit = process.env.NATIVE_LLM_MODEL_TOKEN_LIMIT || 4096;
if (!limit || isNaN(Number(limit)))
throw new Error("No NativeAI token context limit was set.");
return Number(limit);
}
// Ensure the user set a value for the token limit
promptWindowLimit() {
const limit = process.env.NATIVE_LLM_MODEL_TOKEN_LIMIT || 4096;

View file

@@ -82,6 +82,13 @@ class OllamaAILLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
const limit = process.env.OLLAMA_MODEL_TOKEN_LIMIT || 4096;
if (!limit || isNaN(Number(limit)))
throw new Error("No Ollama token context limit was set.");
return Number(limit);
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {

View file

@@ -2,6 +2,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
handleDefaultStreamResponseV2,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
class OpenAiLLM {
constructor(embedder = null, modelPreference = null) {
@@ -38,27 +39,12 @@ class OpenAiLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(modelName) {
return MODEL_MAP.openai[modelName] ?? 4_096;
}
promptWindowLimit() {
switch (this.model) {
case "gpt-3.5-turbo":
case "gpt-3.5-turbo-1106":
return 16_385;
case "gpt-4o":
case "gpt-4o-2024-08-06":
case "gpt-4o-2024-05-13":
case "gpt-4o-mini":
case "gpt-4o-mini-2024-07-18":
case "gpt-4-turbo":
case "gpt-4-1106-preview":
case "gpt-4-turbo-preview":
return 128_000;
case "gpt-4":
return 8_192;
case "gpt-4-32k":
return 32_000;
default:
return 4_096; // assume a fine-tune 3.5?
}
return MODEL_MAP.openai[this.model] ?? 4_096;
}
// Short circuit if name has 'gpt' since we now fetch models from OpenAI API

View file

@@ -117,6 +117,17 @@ class OpenRouterLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(modelName) {
const cacheModelPath = path.resolve(cacheFolder, "models.json");
const availableModels = fs.existsSync(cacheModelPath)
? safeJsonParse(
fs.readFileSync(cacheModelPath, { encoding: "utf-8" }),
{}
)
: {};
return availableModels[modelName]?.maxLength || 4096;
}
promptWindowLimit() {
const availableModels = this.models();
return availableModels[this.model]?.maxLength || 4096;

View file

@@ -52,6 +52,11 @@ class PerplexityLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(modelName) {
const availableModels = perplexityModels();
return availableModels[modelName]?.maxLength || 4096;
}
promptWindowLimit() {
const availableModels = this.allModelInformation();
return availableModels[this.model]?.maxLength || 4096;

View file

@@ -48,6 +48,13 @@ class TextGenWebUILLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(_modelName) {
const limit = process.env.TEXT_GEN_WEB_UI_MODEL_TOKEN_LIMIT || 4096;
if (!limit || isNaN(Number(limit)))
throw new Error("No token context limit was set.");
return Number(limit);
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {

View file

@@ -48,6 +48,11 @@ class TogetherAiLLM {
return "streamGetChatCompletion" in this;
}
static promptWindowLimit(modelName) {
const availableModels = togetherAiModels();
return availableModels[modelName]?.maxLength || 4096;
}
// Ensure the user set a value for the token limit
// and if undefined - assume 4096 window.
promptWindowLimit() {

View file

@@ -136,9 +136,11 @@ const docSummarizer = {
);
}
const { TokenManager } = require("../../../helpers/tiktoken");
if (
document.content?.length <
Provider.contextLimit(this.super.provider)
new TokenManager(this.super.model).countFromString(
document.content
) < Provider.contextLimit(this.super.provider, this.super.model)
) {
return document.content;
}
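
Both this hunk and the webScraping hunk that follows swap the old character-length comparison for a token count from the existing TokenManager helper, so the guard is measured in the same unit as the provider window. A hedged sketch of the new check (the fitsInWindow wrapper is illustrative, not part of the commit; TokenManager usage mirrors the hunk above):

// Illustrative wrapper around the pattern used in docSummarizer and webScraping.
const { TokenManager } = require("./server/utils/helpers/tiktoken");              // path per the hunk's relative require
const Provider = require("./server/utils/agents/aibitat/providers/ai-provider");  // path assumed

function fitsInWindow(content, provider, model) {
  const tokens = new TokenManager(model).countFromString(content);
  return tokens < Provider.contextLimit(provider, model);
}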

View file

@@ -77,7 +77,11 @@ const webScraping = {
throw new Error("There was no content to be collected or read.");
}
if (content.length < Provider.contextLimit(this.super.provider)) {
const { TokenManager } = require("../../../helpers/tiktoken");
if (
new TokenManager(this.super.model).countFromString(content) <
Provider.contextLimit(this.super.provider, this.super.model)
) {
return content;
}

View file

@@ -15,6 +15,7 @@ const { ChatAnthropic } = require("@langchain/anthropic");
const { ChatBedrockConverse } = require("@langchain/aws");
const { ChatOllama } = require("@langchain/community/chat_models/ollama");
const { toValidNumber } = require("../../../http");
const { getLLMProviderClass } = require("../../../helpers");
const DEFAULT_WORKSPACE_PROMPT =
"You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
@@ -173,15 +174,16 @@ class Provider {
}
}
static contextLimit(provider = "openai") {
switch (provider) {
case "openai":
return 8_000;
case "anthropic":
return 100_000;
default:
return 8_000;
}
/**
* Get the context limit for a provider/model combination using static method in AIProvider class.
* @param {string} provider
* @param {string} modelName
* @returns {number}
*/
static contextLimit(provider = "openai", modelName) {
const llm = getLLMProviderClass({ provider });
if (!llm || !llm.hasOwnProperty("promptWindowLimit")) return 8_000;
return llm.promptWindowLimit(modelName);
}
// For some providers we may want to override the system prompt to be more verbose.
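
With contextLimit delegating to the provider class, the agent-side number always matches what the chat path reports for the same model. A small equivalence sketch (require paths assumed; values from MODEL_MAP):

// Equivalence sketch - both calls resolve through the same static lookup.
const Provider = require("./server/utils/agents/aibitat/providers/ai-provider");  // path assumed
const { AnthropicLLM } = require("./server/utils/AiProviders/anthropic");          // path assumed

Provider.contextLimit("anthropic", "claude-3-haiku-20240307"); // 200_000
AnthropicLLM.promptWindowLimit("claude-3-haiku-20240307");     // 200_000 - same MODEL_MAP entry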

View file

@@ -20,6 +20,11 @@
* @property {Function} compressMessages - Compresses chat messages to fit within the token limit.
*/
/**
* @typedef {Object} BaseLLMProviderClass - Class method of provider - not instantiated
* @property {function(string): number} promptWindowLimit - Returns the token limit for the provided model.
*/
/**
* @typedef {Object} BaseVectorDatabaseProvider
* @property {string} name - The name of the Vector Database instance.
@@ -204,6 +209,78 @@ function getEmbeddingEngineSelection() {
}
}
/**
* Returns the LLMProviderClass - this is a helper method to access static methods on a class
* @param {{provider: string | null} | null} params - Initialize params for LLMs provider
* @returns {BaseLLMProviderClass}
*/
function getLLMProviderClass({ provider = null } = {}) {
switch (provider) {
case "openai":
const { OpenAiLLM } = require("../AiProviders/openAi");
return OpenAiLLM;
case "azure":
const { AzureOpenAiLLM } = require("../AiProviders/azureOpenAi");
return AzureOpenAiLLM;
case "anthropic":
const { AnthropicLLM } = require("../AiProviders/anthropic");
return AnthropicLLM;
case "gemini":
const { GeminiLLM } = require("../AiProviders/gemini");
return GeminiLLM;
case "lmstudio":
const { LMStudioLLM } = require("../AiProviders/lmStudio");
return LMStudioLLM;
case "localai":
const { LocalAiLLM } = require("../AiProviders/localAi");
return LocalAiLLM;
case "ollama":
const { OllamaAILLM } = require("../AiProviders/ollama");
return OllamaAILLM;
case "togetherai":
const { TogetherAiLLM } = require("../AiProviders/togetherAi");
return TogetherAiLLM;
case "perplexity":
const { PerplexityLLM } = require("../AiProviders/perplexity");
return PerplexityLLM;
case "openrouter":
const { OpenRouterLLM } = require("../AiProviders/openRouter");
return OpenRouterLLM;
case "mistral":
const { MistralLLM } = require("../AiProviders/mistral");
return MistralLLM;
case "native":
const { NativeLLM } = require("../AiProviders/native");
return NativeLLM;
case "huggingface":
const { HuggingFaceLLM } = require("../AiProviders/huggingface");
return HuggingFaceLLM;
case "groq":
const { GroqLLM } = require("../AiProviders/groq");
return GroqLLM;
case "koboldcpp":
const { KoboldCPPLLM } = require("../AiProviders/koboldCPP");
return KoboldCPPLLM;
case "textgenwebui":
const { TextGenWebUILLM } = require("../AiProviders/textGenWebUI");
return TextGenWebUILLM;
case "cohere":
const { CohereLLM } = require("../AiProviders/cohere");
return CohereLLM;
case "litellm":
const { LiteLLM } = require("../AiProviders/liteLLM");
return LiteLLM;
case "generic-openai":
const { GenericOpenAiLLM } = require("../AiProviders/genericOpenAi");
return GenericOpenAiLLM;
case "bedrock":
const { AWSBedrockLLM } = require("../AiProviders/bedrock");
return AWSBedrockLLM;
default:
return null;
}
}
// Some models have lower restrictions on chars that can be encoded in a single pass
// and by default we assume it can handle 1,000 chars, but some models work with smaller
// chars so here we can override that value when embedding information.
@@ -228,6 +305,7 @@ module.exports = {
getEmbeddingEngineSelection,
maximumChunkLength,
getVectorDbClass,
getLLMProviderClass,
getLLMProvider,
toChunks,
};
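
For reference, the newly exported helper gives agent code access to these static methods without instantiating a provider. A hedged usage sketch (require path assumed from the layout in this diff):

// Usage sketch - require path assumed.
const { getLLMProviderClass } = require("./server/utils/helpers");

const GroqLLM = getLLMProviderClass({ provider: "groq" });
GroqLLM.promptWindowLimit("llama3-70b-8192");        // 8192 via MODEL_MAP.groq
getLLMProviderClass({ provider: "not-a-provider" }); // null - Provider.contextLimit then falls back to 8_000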