summaryrefslogtreecommitdiff
path: root/backend/instance/chat.ts
diff options
context:
space:
mode:
Diffstat (limited to 'backend/instance/chat.ts')
-rw-r--r--backend/instance/chat.ts322
1 files changed, 322 insertions, 0 deletions
diff --git a/backend/instance/chat.ts b/backend/instance/chat.ts
new file mode 100644
index 0000000..c0cef4f
--- /dev/null
+++ b/backend/instance/chat.ts
@@ -0,0 +1,322 @@
+import fs from 'fs';
+import crypto from 'crypto';
+import { LlamaTokenizer } from 'llama-tokenizer-js';
+
+import { ChatMsg, ChatOptions, ChatStream, ModelInfo } from './api-defs.js';
+import * as api from './api-connection.js';
+import * as notes from './notes.js';
+import * as tools from './tools.js';
+import * as u from './utils.js';
+
+// Shorter local names for the API-definition types imported above.
+export type Msg = ChatMsg;
+export type Stream = ChatStream;
+export type Options = ChatOptions;
+
+// Token budget that truncateMsgs trims the message history against; overridable via setConfig.
+let CONTEXT_WINDOW = 12000;
+// Directory that save() and load() use for chat files; overridable via setConfig.
+let CHAT_DIR = `../data/chats`;
+
+// Single tokenizer instance shared by this module for token counting.
+const tokenizer = new LlamaTokenizer();
+
+/**
+ * Applies overrides for this module's settings.
+ *
+ * Only truthy `contextWindow` / `chatDir` values are taken from `config`;
+ * a missing or falsy value leaves the current setting untouched.
+ */
+export async function setConfig( config: any ) {
+  const windowOverride = config.contextWindow;
+  CONTEXT_WINDOW = windowOverride || CONTEXT_WINDOW;
+
+  const dirOverride = config.chatDir;
+  CHAT_DIR = dirOverride || CHAT_DIR;
+}
+
+// Model name sent to the generate endpoint when producing chat titles.
+const TITLE_GEN_MODEL = "qwen25-custom-1b";
+
+/**
+ * Asks the title model for a short title describing the start of a chat.
+ *
+ * The prompt instructs the model to answer inside <TITLE></TITLE> tags. When
+ * the reply contains such a pair (any letter case), only the text between the
+ * tags is returned. Otherwise the known markers are stripped from the raw
+ * reply. The result is always trimmed.
+ *
+ * @param firstMsg first user message of the chat
+ * @param response the chatbot's reply to that message
+ * @returns the generated title
+ * @throws whatever `generate` throws (e.g. on a non-OK HTTP response)
+ */
+async function generateTitle( firstMsg: Msg, response: string ) {
+  const prompt =
+`Your purpose is to generate a title for a chat given the beginning of a conversation between a user and a chatbot.
+The title should be a short summary of the topic, no longer than 50 characters.
+Output the title between <TITLE></TITLE> tags.
+
+<MESSAGE_LIST>
+user: ${firstMsg.content}
+
+chatbot: ${response}
+</MESSAGE_LIST>
+
+`;
+  const model = {
+    modelname: TITLE_GEN_MODEL
+  } as ModelInfo;
+
+  const raw = await generate( prompt, '', { model } as Options );
+
+  // Preferred path: the model followed the instructions and wrapped the title in tags.
+  const tagged = /<title>([\s\S]*?)<\/title>/i.exec( raw );
+  const inner = tagged?.[1];
+  if( inner !== undefined )
+    return inner.trim();
+
+  // Fallback: drop every stray tag and a leading "title: " label, then trim.
+  return raw
+    .replace( /<\/?title>/gi, '' )
+    .replace( /^\s*title:\s*/i, '' )
+    .trim();
+}
+
+/**
+ * Sends a conversation to the chat endpoint at http://127.0.0.1:11434/api/chat
+ * and collects the streamed reply into a single assistant message.
+ *
+ * Streamed text is buffered while `tools.isToolStr` says it looks like a tool
+ * call. Once the buffer parses as JSON it becomes the message's `toolCall`.
+ * Ordinary text is appended to the reply and, unless `ignoreOutput` is set,
+ * echoed to stdout and passed to `onChunk`.
+ *
+ * @param msgs conversation history, oldest first
+ * @param options chat options; `options.model` selects the model
+ * @param ignoreOutput when true, chunks are neither printed nor passed to `onChunk`
+ * @param notelist notes forwarded to `parseMsgs` for the system prompt
+ * @param onChunk called with each piece of reply text as it arrives
+ * @returns the assistant message; `title` is set when `options.generateTitle`
+ *          is enabled and `msgs` holds exactly one message
+ * @throws Error when the HTTP response is not OK or has no readable body
+ */
+export async function run(
+  msgs: Msg[],
+  options: Options,
+  ignoreOutput: boolean,
+  notelist: notes.Note[],
+  onChunk: Function
+) : Promise<Msg> {
+  const ctx = parseMsgs( msgs, options, notelist );
+  const body = {
+    model: options.model.modelname,
+    messages: ctx
+  };
+
+  const res = await fetch( "http://127.0.0.1:11434/api/chat", {
+    method: 'POST',
+    headers: { 'Content-Type': 'application/json' },
+    body: JSON.stringify( body )
+  } );
+
+  if( !res.ok ) {
+    console.error( res );
+    throw new Error( "failed to receive response " + res.status );
+  }
+
+  const reader = res.body?.getReader();
+  if( !reader ) throw new Error( 'cannot get reader' );
+
+  api.serverNotify( { loadedModel: options.model.name } );
+  let content = '';
+  let toolBuffer = '';
+  let toolCall: tools.Call | undefined = undefined;
+
+  // Appends reply text and forwards it to the console / caller unless output is suppressed.
+  const emit = ( text: string ) => {
+    content += text;
+    if( !ignoreOutput ) {
+      process.stdout.write( text );
+      onChunk( text );
+    }
+  };
+
+  // The read result is handed straight to parseChunkedJson, so the loop pulls chunks manually.
+  for( let read = await reader.read(); !read.done; read = await reader.read() ) {
+    // NOTE(review): returning true from the callback presumably tells parseChunkedJson to stop - confirm in utils.
+    const parsed = u.parseChunkedJson( read, ( json: any ) : boolean | void => {
+      if( !json.message ) return true;
+
+      // Guard against a chunk without text so "undefined" never leaks into the reply.
+      toolBuffer += json.message.content ?? '';
+      if( tools.isToolStr( toolBuffer ) ) {
+        try {
+          toolCall = JSON.parse( toolBuffer );
+          if( toolCall ) {
+            content += toolBuffer;
+            return true;
+          }
+        } catch( e ) {
+          // JSON is still incomplete; keep buffering until it parses.
+        }
+        return;
+      }
+
+      const text = toolBuffer;
+      toolBuffer = '';
+      emit( text );
+    } );
+
+    if( !parsed )
+      break;
+  }
+
+  // Text that looked like a tool call but never became valid JSON was being
+  // withheld; deliver it as ordinary reply text instead of dropping it.
+  if( !toolCall && toolBuffer ) {
+    emit( toolBuffer );
+    toolBuffer = '';
+  }
+
+  console.log();
+  const ret: Msg = {
+    timestamp: u.getTimestamp(),
+    role: 'assistant',
+    content,
+    toolCall
+  };
+
+  if( options.generateTitle && msgs.length === 1 ) {
+    // "not busy" notification is sent by generate(), which generateTitle calls
+    const title = await generateTitle( msgs[0], content );
+    ret.title = title;
+  }
+  else {
+    api.serverNotify( { loadedModel: options.model.name, isBusy: false } );
+  }
+  return ret;
+}
+
+/**
+ * Runs a raw completion against http://127.0.0.1:11434/api/generate and
+ * returns the concatenated streamed text.
+ *
+ * Every piece of text is written to stdout and passed to `onChunk`. The server
+ * is notified as busy while streaming and as not busy afterwards, including
+ * when streaming fails part-way.
+ *
+ * @param prompt text to complete
+ * @param suffix text sent as the request's `suffix` field
+ * @param options chat options; `options.model` selects the model and
+ *                `options.system` (if set) is expanded into the system prompt
+ * @param onChunk called with each piece of generated text
+ * @returns the full generated text
+ * @throws Error when the HTTP response is not OK or has no readable body
+ */
+export async function generate( prompt: string, suffix: string, options: Options, onChunk: Function = () => {} ) : Promise<string> {
+  const body = {
+    model: options.model.modelname,
+    system: options.system ? getFullSystem( options, [] ) : '',
+    prompt,
+    suffix
+  };
+
+  const res = await fetch( `http://127.0.0.1:11434/api/generate`, {
+    method: 'POST',
+    headers: { 'Content-Type': 'application/json' },
+    body: JSON.stringify( body )
+  } );
+
+  if( !res.ok ) {
+    console.error( res );
+    if( res.body )
+      console.log( await res.text() );
+    // Space before the status code, matching the message thrown by run().
+    throw new Error( "failed to receive response " + res.status );
+  }
+
+  const reader = res.body?.getReader();
+  if( !reader ) throw new Error( 'cannot get reader' );
+
+  api.serverNotify( { loadedModel: options.model.name, isBusy: true } );
+  let content = '';
+  try {
+    for( let read = await reader.read(); !read.done; read = await reader.read() ) {
+      // NOTE(review): returning true from the callback presumably tells parseChunkedJson to stop - confirm in utils.
+      const parsed = u.parseChunkedJson( read, ( json: any ) : boolean | void => {
+        const msg = json.response;
+        if( !msg ) return true;
+        content += msg;
+        process.stdout.write( msg );
+        onChunk( msg );
+      } );
+
+      if( !parsed )
+        break;
+    }
+  } finally {
+    // Always clear the busy flag, even if reading or a chunk handler throws.
+    api.serverNotify( { loadedModel: options.model.name, isBusy: false } );
+  }
+
+  console.log();
+  return content;
+}
+
+/**
+ * Encrypts a message log with AES-256-GCM and writes it to
+ * `${CHAT_DIR}/${uuid}.json`.
+ *
+ * The key is derived with PBKDF2 (sha512, 100000 rounds) from `u.jwt_secret()`
+ * and a fresh 16-byte salt. The file is one hex string laid out as
+ * salt (16 bytes), IV (12 bytes), auth tag, then ciphertext.
+ *
+ * @param msglog messages to persist
+ * @param uuid chat identifier, used as the file name
+ */
+export function save( msglog: Msg[], uuid: string ) {
+  const plaintext = JSON.stringify( msglog );
+
+  const salt = crypto.randomBytes( 16 );
+  const key = crypto.pbkdf2Sync( u.jwt_secret(), salt, 100000, 32, 'sha512' );
+  const iv = crypto.randomBytes( 12 );
+
+  const cipher = crypto.createCipheriv( 'aes-256-gcm', key, iv );
+  const ciphertext = Buffer.concat( [
+    cipher.update( plaintext, 'utf8' ),
+    cipher.final()
+  ] );
+  const authTag = cipher.getAuthTag();
+
+  // Hex-encoding the concatenated buffers equals concatenating the hex of each part.
+  const payload = Buffer.concat( [ salt, iv, authTag, ciphertext ] ).toString( 'hex' );
+
+  fs.writeFileSync( `${CHAT_DIR}/${uuid}.json`, payload );
+  console.log( "== [ chat saved ] ==" );
+}
+
+/**
+ * Reads and decrypts a chat file written by `save`.
+ *
+ * The file is a hex string: 32 chars of salt, 24 chars of IV, 32 chars of
+ * GCM auth tag, then the ciphertext. The key is re-derived with PBKDF2 from
+ * `u.jwt_secret()` and the stored salt.
+ *
+ * This is best-effort: any failure yields an empty list. A missing file is
+ * silent, while every other failure (bad auth tag, corrupt data, unexpected
+ * JSON shape) is logged so it is not mistaken for an empty chat.
+ *
+ * @param filename chat identifier (file name without the `.json` extension)
+ * @returns the stored messages, or `[]` when the chat cannot be loaded
+ */
+export function load( filename: string ) : Msg[] {
+  // Hex-character offsets of the fields inside the file.
+  const SALT_END = 32;
+  const IV_END = SALT_END + 24;
+  const TAG_END = IV_END + 32;
+
+  const chatfile = `${CHAT_DIR}/${filename}.json`;
+  try {
+    const contents = fs.readFileSync( chatfile, 'utf8' );
+    const salt = Buffer.from( contents.slice( 0, SALT_END ), 'hex' );
+    const iv = Buffer.from( contents.slice( SALT_END, IV_END ), 'hex' );
+    const authTag = Buffer.from( contents.slice( IV_END, TAG_END ), 'hex' );
+    const data = contents.slice( TAG_END );
+
+    const key = crypto.pbkdf2Sync( u.jwt_secret(), salt, 100000, 32, 'sha512' );
+    const decipher = crypto.createDecipheriv( 'aes-256-gcm', key, iv );
+    decipher.setAuthTag( authTag );
+
+    let decrypted = decipher.update( data, 'hex', 'utf8' );
+    decrypted += decipher.final( 'utf8' );
+
+    const parsed: unknown = JSON.parse( decrypted );
+    if( !Array.isArray( parsed ) ) {
+      console.error( `chat file ${chatfile} does not contain a message list` );
+      return [];
+    }
+    return parsed as Msg[];
+  } catch( e ) {
+    // A chat that was never saved is the expected case; anything else is worth surfacing.
+    const code = e instanceof Error ? ( e as { code?: unknown } ).code : undefined;
+    if( code !== 'ENOENT' )
+      console.error( `failed to load chat ${chatfile}:`, e );
+    return [];
+  }
+}
+
+/**
+ * Builds the message list sent to the model: the history truncated to fit the
+ * context window, preceded by a system message when a system prompt exists.
+ *
+ * The system prompt's token count is reserved before truncating, so the
+ * prompt and the history together stay inside the window.
+ *
+ * @param msgs conversation history, oldest first
+ * @param options chat options; `options.system` enables the system prompt
+ * @param notelist notes substituted into the system prompt
+ * @returns a new list of messages; the input messages are not modified
+ */
+export function parseMsgs( msgs: Msg[], options: Options, notelist: notes.Note[] ) : Msg[] {
+  const fullSystem = options.system ? getFullSystem( options, notelist ) : '';
+  const hasSystem = fullSystem.length > 1;
+
+  // Previously always 0, which let the system prompt push the request past the window.
+  const reservedTokens = hasSystem ? tokenizer.encode( fullSystem ).length : 0;
+
+  const res = truncateMsgs( msgs, reservedTokens );
+
+  if( hasSystem ) {
+    res.unshift( {
+      timestamp: u.getTimestamp(),
+      role: 'system',
+      content: fullSystem
+    } );
+  }
+
+  return res;
+}
+
+/**
+ * Assembles the system prompt from `options.system.model` followed by
+ * `options.system.user`, then fills in the first occurrence of each placeholder:
+ *
+ *   <|system_time|>  - sentence containing the current timestamp
+ *   <|tools_list|>   - output of `tools.getPromptStr( options )`
+ *   <|notes_str|>    - output of `notes.getPromptStr( notelist )`
+ *
+ * @param options chat options carrying the system prompt parts
+ * @param notelist notes rendered into the notes placeholder
+ * @returns the expanded system prompt, or '' when `options.system` is unset
+ */
+export function getFullSystem( options: Options, notelist: notes.Note[] ) : string {
+  if( !options.system )
+    return '';
+
+  const notesPrompt = notes.getPromptStr( notelist );
+  let fullSystem = '';
+  if( options.system.model )
+    fullSystem += options.system.model;
+  if( options.system.user )
+    fullSystem += options.system.user;
+
+  // Function replacers insert the text verbatim. A plain string replacement
+  // would interpret sequences such as "$&" or "$'" inside notes or tool text.
+  const timeStr = `the current system time is ${u.getTimestamp()}.`;
+  fullSystem = fullSystem.replace( '<|system_time|>', () => timeStr );
+  const toolsStr = tools.getPromptStr( options );
+  fullSystem = fullSystem.replace( '<|tools_list|>', () => toolsStr );
+  fullSystem = fullSystem.replace( '<|notes_str|>', () => notesPrompt );
+
+  return fullSystem;
+}
+
+/**
+ * Renders a message's text attachments as a block to append to its content.
+ *
+ * Only files whose `type` is 'text' are included; each is emitted as a
+ * `[file: name]` header followed by the file content.
+ *
+ * @param msg message whose `files` are rendered
+ * @returns the attachment block, or '' when the message has no text files
+ */
+function parseMsgFiles( msg: Msg ) : string {
+  if( !msg.files || !msg.files.length ) return '';
+
+  let sections = '';
+  for( const f of Object.values( msg.files ) ) {
+    if( f.type !== 'text' )
+      continue;
+
+    sections += `\n\n[file: ${f.name}]\n`;
+    sections += f.content;
+  }
+
+  // Without any text file there is nothing to list, so skip the header too.
+  return sections ? 'attached files:' + sections : '';
+}
+
+/**
+ * Selects the most recent messages that fit the token budget
+ * `CONTEXT_WINDOW - 768 - reservedTokens`.
+ *
+ * Messages are walked newest to oldest. Text attachments are folded into the
+ * content and the `files` field is dropped. The first message that no longer
+ * fits is cut down to its tail, sized to the remaining budget and prefixed
+ * with a truncation marker; everything older is discarded. Two extra tokens
+ * are counted per message as overhead.
+ *
+ * @param msgs conversation history, oldest first
+ * @param reservedTokens tokens already spoken for (e.g. the system prompt)
+ * @returns copies of the retained messages, oldest first
+ */
+function truncateMsgs( msgs: Msg[], reservedTokens: number ) : Msg[] {
+  const maxTokens = CONTEXT_WINDOW - 768 - reservedTokens;
+  let totalTokens = 0;
+
+  const start = Date.now();
+
+  const ret: Msg[] = [];
+  for( let i = msgs.length - 1; i >= 0; --i ) {
+    // Work on a deep copy so the caller's messages keep their files and content.
+    const msg = JSON.parse( JSON.stringify( msgs[i] ) );
+
+    let content = msg.content;
+    const attachments = parseMsgFiles( msg );
+    if( attachments )
+      content += '\n\n' + attachments; // keep the listing apart from the message text
+
+    const tokens = tokenizer.encode( content );
+    let truncated = false;
+
+    if( totalTokens + tokens.length > maxTokens ) {
+      // Tokens still available for this message, minus a 10-token margin.
+      const budget = maxTokens - totalTokens - 10;
+      if( content.length <= 20 || budget <= 0 )
+        break;
+
+      // Estimate characters from tokens proportionally and keep the END of the message.
+      // budget < tokens.length here, so the kept length is always shorter than the content.
+      const keepLen = Math.floor( content.length * budget / tokens.length );
+      if( keepLen <= 0 )
+        break;
+
+      content = "<MESSAGE TRUNCATED>..." + content.slice( content.length - keepLen );
+      truncated = true;
+    }
+
+    totalTokens += tokens.length + 2;
+
+    delete msg.files;
+    msg.content = content;
+    ret.unshift( msg );
+    if( truncated )
+      break;
+  }
+
+  const end = Date.now();
+  console.log( `tokenized ${ret.length} messages in ${end - start}ms` );
+
+  return ret;
+}
+