diff options
Diffstat (limited to 'backend/api/src/chat.zig')
| -rw-r--r-- | backend/api/src/chat.zig | 586 |
1 files changed, 586 insertions, 0 deletions
diff --git a/backend/api/src/chat.zig b/backend/api/src/chat.zig new file mode 100644 index 0000000..1200562 --- /dev/null +++ b/backend/api/src/chat.zig @@ -0,0 +1,586 @@ +const z = @import( "std" ); +const zap = @import( "zap" ); +const api = @import( "api.zig" ); +const req = @import( "req.zig" ); +const net = @import( "net-util.zig" ); +const model = @import( "model.zig" ); + +const u = struct { + usingnamespace @import( "util.zig" ); + usingnamespace @import( "userdefs.zig" ); + usingnamespace @import( "user.zig" ); + usingnamespace @import( "userdata.zig" ); +}; + +// -------------------------------------------------------------------------------------------------------- +// some docs because this is becoming a little complex: +// the api is mostly a proxy between the user and model instances +// user sends ChatReq to api, api forwards a ModelInstanceChatReq to instance after verifying user acc etc +// instance replies with a ChatStream defined in typescript [api-defs.ts](file://../instance/api-defs.ts#L58) +// api simply forwards this ChatStream to the client, only parsing it to check for a title. +// if a title is present, it's updated in the user db entry. +// +// the situation looks similar for the generate endpoint except there are no title checks. +// the msg is simply proxied to the user. +// +// legend: +// (source) - message source +// [type] - message type +// {target} - message target +// @/ - route +// +// chat route: +// (user) [ChatReq] req @/chat --> (api) [ModelInstanceChatReq] req @/chat --> {instance} +// (instance) [ChatStream] res --> (api) proxy --> {user} +// +// generate route: +// (user) [GenerateReq] req @/generate --> (api) [ModelInstanceGenerateReq] req @/generate --> {instance} +// (instance) [undefined in zig] res --> (api) proxy --> {user} +// -------------------------------------------------------------------------------------------------------- + +const ArrayList = z.ArrayList; +const ErrRes = net.ErrorResponse; +const OkRes = net.OkResponse; +const Request = zap.Request; + +const memeql = z.mem.eql; + +const alloc = u.alloc; + +pub const title_gen_model = "qwen2.5-1.5b"; +pub const routes = .{ + .@"chat" = chat, + .@"get-chat" = getChat, + .@"generate" = generate, + .@"create-chat" = createChat, + .@"delete-chat" = deleteChat, + .@"get-chat-name" = getChatName +}; + +const ToolCall = struct { + name: []const u8, + parameters: struct { + pub const @"getty.db" = u.@"json.ignore.unknown"; + }// processed either by instance or client, we dont care abt the contents +}; + +const MsgFile = struct { + name: []const u8, + type: []const u8, + content: []const u8 +}; + +const Msg = struct { + content: []const u8, + role: []const u8, + timestamp: []const u8, + toolCall: ?ToolCall = null, + images: ?[][]const u8 = null, + files: ?[]MsgFile = null, + + pub const @"getty.db" = u.@"json.ignore.unknown"; +}; + +const ClientOptions = struct { + seed: ?u32, + temperature: ?f32, + + pub const @"getty.db" = u.@"json.ignore.unknown"; +}; + +const ChatReq = struct { + model: []const u8, + messages: []Msg, + system: ?[]const u8, + chatfile: ?[]const u8 = null, + options: ?ClientOptions = null, + generateTitle: ?bool = null, + + pub const @"getty.db" = u.@"json.ignore.unknown"; +}; + +const ChatStream = struct { + status: []const u8, + done: bool, + title: ?[]const u8, + + pub const @"getty.db" = u.@"json.ignore.unknown"; +}; + +const GenerateReq = struct { + model: []const u8, + prompt: []const u8, + suffix: ?[]const u8 = "", + options: ?ClientOptions = null, + + pub const @"getty.db" = u.@"json.ignore.unknown"; +}; + +const GetChatReq = struct { + chatId: []const u8, + + pub const @"getty.db" = u.@"json.ignore.unknown"; +}; + +const ChatMsgContext = struct { + uuid: []const u8, + chatfile: []const u8 +}; + + +// theres no reason for these to be split up other than the frontend being suck +// todo later: fix +const ModelInstanceChatOptions = struct { + system: ?struct { + model: ?[]const u8 = null, + user: ?[]const u8 = null, + }, + model: model.Model, + uuid: []const u8, + chatfile: ?[]const u8 = null, + generateTitle: ?bool = null, +}; + +const ModelInstanceChatReq = struct { + messages: []Msg, + options: ModelInstanceChatOptions, +}; + +const ModelInstanceGenerateOptions = struct { + model: model.Model, + system: ?struct { + model: ?[]const u8 = null, + user: ?[]const u8 = null, + } = null, +}; + +const ModelInstanceGenerateReq = struct { + options: ModelInstanceGenerateOptions, + prompt: []const u8, + suffix: ?[]const u8, +}; + +///forwards the instance error to client if reported +///sends a 500 otherwise +fn handleInstanceError( r: Request, q: *const req.Result ) void { + if( q.body ) |body| { + const msg = u.jsonParse( struct { status: []const u8, msg: []const u8 }, body ) catch null; + if( msg != null ) { + z.debug.print( "res err status [{any}]: {s}\n", .{ q.status, body } ); + r.sendChunk( body ) catch {}; + } + } else { + net.sendJsonChunk( r, .internal_server_error, ErrRes{ .msg = "could not contact server" } ); + } +} + + +// ------------------------------------------------------------------------------------------- +// chat -------------------------------------------------------------------------------------- + +fn chatChunkFn( chunk: []const u8, r: Request ) void { + const ctx = r.getUserContext( ChatMsgContext ); + var parts = z.mem.split( u8, chunk, "\n" ); + var part: ?[]const u8 = parts.first(); + while( part ) |p| : (part = parts.next()) { + if( p.len == 0 ) + continue; + + var buf = alloc.alloc( u8, p.len + 1 ) catch return; + defer alloc.free( buf ); + @memcpy( buf[0..p.len], p ); + buf[p.len] = '\n'; + + if( ctx ) |_ctx| { + if( u.jsonParse( ChatStream, buf ) catch null ) |parsed| { + if( parsed.v.done and parsed.v.title != null ) { + setChatTitle( _ctx.uuid, _ctx.chatfile, parsed.v.title.? ); + } + + parsed.deinit(); + } + } + + r.sendChunk( buf ) catch |e| { + z.debug.print( "error sending chunk: {any}\n", .{e} ); + }; + } +} + +fn startChat( r: Request, address: []const u8, params: ChatReq, uuid: []const u8, user: *u.UserEntry ) void { + const _model = model.getModelByDisplayName( params.model ) catch { + return net.sendJsonChunk( r, .bad_request, ErrRes{ .msg = "invalid model" } ); + }; + + if( !model.canBeUsedByUser( _model.*, user ) ) { + return net.sendJson( r, .unauthorized, ErrRes{ .msg = "user is not authorized to use this model" } ); + } + + const chat_req = ModelInstanceChatReq{ + .messages = params.messages, + .options = ModelInstanceChatOptions{ + .model = _model.*, + .system = .{ + .model = _model.system, + .user = params.system, + }, + .uuid = uuid, + .chatfile = params.chatfile, + .generateTitle = params.generateTitle + }, + }; + + if( params.chatfile ) |chatfile| { + var ctx = ChatMsgContext{ + .uuid = uuid, + .chatfile = chatfile + }; + r.setUserContext( &ctx ); + } + + const params_str = u.jsonStringify( chat_req ) catch { + return net.sendJson( r, .bad_request, ErrRes{ .msg = "invalid request format" } ); + }; defer alloc.free( params_str ); + + r.setChunked() catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "internal server error" } ); + }; + const q = req.sendWithData( Request, .{ + .method = .POST, + .url = address, + .headers = &[_]z.http.Header{ .{ .name = "Content-Type", .value = "application/json" } }, + .body = params_str, + .chunk_fn = chatChunkFn, + .chunk_data = r + }, alloc ); + defer q.deinit(); + + if( !q.ok ) + handleInstanceError( r, &q ); + r.endStream() catch |e| { + return z.debug.print( "error closing connection {any} {any}", .{e, @errorReturnTrace() } ); + }; +} + +///route @/chat +fn chat( r: Request ) void { + if( net.handleInvalidPostReq( r ) ) return; + u.checkDbOnThread() catch return; + + const uuid = u.uuidFromApiOrAuthToken( r ) catch return; + defer alloc.free( uuid ); + + const params = u.jsonParse( ChatReq, r.body.? ) catch |e| { + z.debug.print( "err: {any}\n", .{e} ); + return net.sendJson( r, .bad_request, ErrRes{ .msg = "invalid request format" } ); + }; defer params.deinit(); + + var user = u.getEntry( uuid ) catch { + return net.sendJson( r, .unauthorized, ErrRes{ .msg = "user not found" } ); + }; defer user.db_entry.?.deinit(); + + u.checkSubscription( &user ) catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "internal server error" } ); + }; + + var prefs = u.getSettings( uuid ) catch { + return net.sendJson( r, .unauthorized, ErrRes{ .msg = "error getting user data" } ); + }; defer prefs.db_entry.?.deinit(); + + if( params.v.chatfile != null and !u.hasChat( &prefs, params.v.chatfile.? ) ) + return net.sendJson( r, .bad_request, ErrRes{ .msg = "specified chat file does not exist" } ); + + const serv = api.getAvailableServer( params.v.model, alloc ) orelse { + return net.sendJson( r, .service_unavailable, ErrRes{ .msg = "no server available" } ); + }; defer alloc.free( serv ); + + api.markServerAsBusy( serv ) catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "internal server error" } ); + }; + + const address = z.fmt.allocPrintZ( alloc, "{s}/chat", .{serv} ) catch return; + defer alloc.free( address ); + startChat( r, address, params.v, uuid, &user ); +} + +fn setChatTitle( uuid: []const u8, chatfile: []const u8, title: []const u8 ) void { + var prefs = u.getSettings( uuid ) catch { + return z.debug.print( "setChatTitle called for invalid uuid: {s}\n", .{uuid} ); + }; defer prefs.db_entry.?.deinit(); + + if( prefs.chat_files == null and prefs.chat_files.?.files == null ) + return; + + const files = prefs.chat_files.?.files.?; + for( files ) |*f| { + if( memeql( u8, f.id, chatfile ) ) { + f.name = title; + break; + } + } + prefs.chat_files = u.ChatFiles{ .files = files }; + u.updateSettings( &prefs ) catch {}; +} + +// ------------------------------------------------------------------------------------------- +// generate ---------------------------------------------------------------------------------- + +fn generateChunkFn( chunk: []const u8, r: Request ) void { + var parts = z.mem.split( u8, chunk, "\n" ); + var part: ?[]const u8 = parts.first(); + while( part ) |p| : (part = parts.next()) { + if( p.len == 0 ) + continue; + + var buf = alloc.alloc( u8, p.len + 1 ) catch return; + defer alloc.free( buf ); + @memcpy( buf[0..p.len], p ); + buf[p.len] = '\n'; + + r.sendChunk( buf ) catch |e| { + z.debug.print( "error sending chunk: {any}\n", .{e} ); + }; + } +} + +fn startGenerate( r: Request, address: []const u8, params: GenerateReq, user: *u.UserEntry ) void { + const _model = model.getModelByDisplayName( params.model ) catch { + const json = u.jsonStringify( ErrRes{ .msg = "invalid model" } ) catch ""; + defer alloc.free( json ); + return r.sendChunk( json ) catch {}; + }; + + if( !model.canBeUsedByUser( _model.*, user ) ) { + return net.sendJson( r, .unauthorized, ErrRes{ .msg = "user is not authorized to use this model" } ); + } + + const gen_req = ModelInstanceGenerateReq{ + .options = ModelInstanceGenerateOptions{ + .model = _model.*, + }, + .prompt = params.prompt, + .suffix = params.suffix, + }; + + const params_str = u.jsonStringify( gen_req ) catch { + return net.sendJsonChunk( r, .bad_request, ErrRes{ .msg = "invalid request format" } ); + }; defer alloc.free( params_str ); + + r.setChunked() catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "internal server error" } ); + }; + const q = req.sendWithData( Request, .{ + .method = .POST, + .url = address, + .headers = &[_]z.http.Header{ .{ .name = "Content-Type", .value = "application/json" } }, + .body = params_str, + .chunk_fn = generateChunkFn, + .chunk_data = r + }, alloc ); + defer q.deinit(); + + if( !q.ok ) + handleInstanceError( r, &q ); + r.endStream() catch |e| { + return z.debug.print( "error closing connection {any} {any}", .{e, @errorReturnTrace() } ); + }; +} + +///route @/generate +fn generate( r: Request ) void { + if( net.handleInvalidPostReq( r ) ) return; + u.checkDbOnThread() catch return; + + const uuid = u.uuidFromApiOrAuthToken( r ) catch return; + defer alloc.free( uuid ); + + const params = u.jsonParse( GenerateReq, r.body.? ) catch |e| { + z.debug.print( "error parsing request {any}", .{e} ); + return net.sendJson( r, .bad_request, ErrRes{ .msg = "invalid request format" } ); + }; defer params.deinit(); + + var user = u.getEntry( uuid ) catch { + return net.sendJson( r, .not_found, ErrRes{ .msg = "user not found" } ); + }; defer user.db_entry.?.deinit(); + + u.checkSubscription( &user ) catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "internal server error" } ); + }; + + const serv = api.getAvailableServer( params.v.model, alloc ) orelse { + return net.sendJson( r, .service_unavailable, ErrRes{ .msg = "no server available" } ); + }; defer alloc.free( serv ); + + api.markServerAsBusy( serv ) catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "internal server error" } ); + }; + + const address = z.fmt.allocPrintZ( alloc, "{s}/generate", .{serv} ) catch return; + defer alloc.free( address ); + startGenerate( r, address, params.v, &user ); +} + + + +// ------------------------------------------------------------------------------------------- +// chat management (create-chat, delete-chat, get-chat, get-chat-name) ----------------------- + +///route @/create-chat +fn createChat( r: Request ) void { + if( net.handleInvalidPostReq( r ) ) return; + u.checkDbOnThread() catch return; + + const uuid = u.uuidFromApiOrAuthToken( r ) catch return; + defer alloc.free( uuid ); + + var userdata = u.getSettings( uuid ) catch { + return net.sendJson( r, .not_found, ErrRes{ .msg = "user not found" } ); + }; defer userdata.db_entry.?.deinit(); + + var chats = userdata.chat_files; + if( chats == null ) + chats = u.ChatFiles{ .files = &[_]u.ChatEntry{} }; + + const files = chats.?.files; + var list = ArrayList( u.ChatEntry ).init( alloc ); + defer list.deinit(); + + if( files != null and files.?.len > 0 ) + list.appendSlice( files.? ) catch {}; + + const chat_uuid = net.uuidv4(); + var buf = [_]u8{ 0 } ** 64; + const dup = z.fmt.bufPrintZ( &buf, "{s}", .{chat_uuid} ) catch &chat_uuid; + const new_entry = u.ChatEntry{ + .id = dup, + .name = "new chat" + }; + + list.append( new_entry ) catch {}; + chats.?.files = list.items; + userdata.chat_files = chats; + u.updateSettings( &userdata ) catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "internal server error" } ); + }; + + return net.sendJson( r, .ok, OkRes( .{ .chatId = dup } ) ); +} + +///route @/get-chat +fn getChat( r: Request ) void { + if( net.handleInvalidPostReq( r ) ) return; + u.checkDbOnThread() catch return; + + const uuid = u.uuidFromApiOrAuthToken( r ) catch return; + defer alloc.free( uuid ); + + const params = u.jsonParse( GetChatReq, r.body.? ) catch { + return net.sendJson( r, .bad_request, ErrRes{ .msg = "invalid request format" } ); + }; defer params.deinit(); + + const userdata = u.getSettings( uuid ) catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "user not found" } ); + }; defer userdata.db_entry.?.deinit(); + + const chats = userdata.chat_files; + if( chats == null or chats.?.files == null ) + return net.sendJson( r, .not_found, ErrRes{ .msg = "chat not found" } ); + + for( chats.?.files.? ) |file| { + const id = file.id; + if( memeql( u8, id, params.v.chatId ) ) { + var pathbuf: [256]u8 = undefined; + const slice = z.fmt.bufPrint( &pathbuf, "../data/chats/{s}.json", .{id} ) catch ""; + const contents = u.readFileCrypto( slice ) catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "error reading file" } ); + }; defer alloc.free( contents ); + + return net.sendJson( r, .ok, OkRes( .{ .name = file.name, .contents = contents } ) ); + } + } + + return net.sendJson( r, .not_found, ErrRes{ .msg = "chat not found" } ); +} + +///route @/delete-chat +fn deleteChat( r: Request ) void { + if( net.handleInvalidPostReq( r ) ) return; + u.checkDbOnThread() catch return; + + const uuid = u.uuidFromApiOrAuthToken( r ) catch return; + defer alloc.free( uuid ); + + const params = u.jsonParse( GetChatReq, r.body.? ) catch { + return net.sendJson( r, .bad_request, ErrRes{ .msg = "invalid request format" } ); + }; defer params.deinit(); + + var userdata = u.getSettings( uuid ) catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "user not found" } ); + }; defer userdata.db_entry.?.deinit(); + + const chats = userdata.chat_files; + if( chats == null or chats.?.files == null ) + return net.sendJson( r, .not_found, ErrRes{ .msg = "chat not found" } ); + + var updated_list = ArrayList( u.ChatEntry ).init( alloc ); + defer updated_list.deinit(); + + var found = false; + for( chats.?.files.? ) |file| { + const id = file.id; + if( memeql( u8, id, params.v.chatId ) ) { + var pathbuf: [256]u8 = undefined; + const slice = z.fmt.bufPrint( &pathbuf, "../data/chats/{s}.json", .{id} ) catch ""; + if( !u.fileExists( slice ) ) { + found = true; continue; + } + + u.deleteFile( slice ) catch continue; + found = true; + } + else { + updated_list.append( file ) catch {}; + } + } + + if( found ) { + userdata.chat_files.?.files = updated_list.items; + u.updateSettings( &userdata ) catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "error updating user data" } ); + }; + + return net.sendJson( r, .ok, OkRes( .{} ) ); + } + + return net.sendJson( r, .not_found, ErrRes{ .msg = "chat not found" } ); +} + +///route @/get-chat-name +fn getChatName( r: Request ) void { + if( net.handleInvalidPostReq( r ) ) return; + u.checkDbOnThread() catch return; + + const uuid = u.uuidFromApiOrAuthToken( r ) catch return; + defer alloc.free( uuid ); + + const params = u.jsonParse( GetChatReq, r.body.? ) catch { + return net.sendJson( r, .bad_request, ErrRes{ .msg = "invalid request format" } ); + }; defer params.deinit(); + + const userdata = u.getSettings( uuid ) catch { + return net.sendJson( r, .internal_server_error, ErrRes{ .msg = "user not found" } ); + }; defer userdata.db_entry.?.deinit(); + + const chats = userdata.chat_files; + if( chats == null or chats.?.files == null ) + return net.sendJson( r, .not_found, ErrRes{ .msg = "chat not found" } ); + + for( chats.?.files.? ) |file| { + const id = file.id; + if( memeql( u8, id, params.v.chatId ) ) { + return net.sendJson( r, .ok, OkRes( .{ .name = file.name } ) ); + } + } + + return net.sendJson( r, .not_found, ErrRes{ .msg = "file not found" } ); +} |
