diff --git a/Cargo.toml b/Cargo.toml index 0153f3e..486db13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,8 @@ dirs = "6.0" futures-util = "0.3" async-channel = "2.3" pulldown-cmark = { version = "0.13.0", default-features = false, features = ["html"] } +thiserror = "2" + +[dev-dependencies] +mockito = "1" +tempfile = "3" diff --git a/README.md b/README.md index f214846..4c4dd8e 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,10 @@ stop_button = "#dc3545" [ollama] url = "http://localhost:11434" timeout_seconds = 120 + +[streaming] +batch_size = 20 +batch_timeout_ms = 100 ``` ## Building diff --git a/src/api.rs b/src/api.rs index 630c2b4..5188e91 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,13 +1,29 @@ -use std::sync::{Arc, Mutex}; use futures_util::StreamExt; use tokio::time::{timeout, Duration}; use crate::types::{ChatMessage, ChatRequest, ModelInfo, ModelsResponse, StreamResponse}; -pub async fn fetch_models(base_url: &str) -> Result, Box> { +/// Typed errors for the Ollama API layer. Using `thiserror` means callers can match +/// on exactly what went wrong instead of downcasting a `Box`. +#[derive(Debug, thiserror::Error)] +pub enum ApiError { + #[error("HTTP request failed: {0}")] + Http(#[from] reqwest::Error), + #[error("Request timed out")] + Timeout, + #[error("Server returned error status {0}")] + BadStatus(u16), + #[error("Failed to parse response: {0}")] + Parse(#[from] serde_json::Error), + #[error("Model returned empty response")] + EmptyResponse, +} + +pub async fn fetch_models(base_url: &str) -> Result, ApiError> { let url = format!("{}/api/tags", base_url); - - // Add timeout to prevent hanging - let response = timeout(Duration::from_secs(10), reqwest::get(&url)).await??; + + let response = timeout(Duration::from_secs(10), reqwest::get(&url)) + .await + .map_err(|_| ApiError::Timeout)??; let models_response: ModelsResponse = response.json().await?; Ok(models_response.models) } @@ -15,13 +31,11 @@ pub async fn fetch_models(base_url: &str) -> Result, Box>>, + messages: Vec, token_sender: async_channel::Sender, -) -> Result<(String, Option), Box> { - let messages = { - let conversation = conversation.lock().unwrap(); - conversation.iter().cloned().collect::>() - }; + batch_size: usize, + batch_timeout_ms: u64, +) -> Result { let request = ChatRequest { model: model.to_string(), @@ -42,15 +56,14 @@ pub async fn send_chat_request_streaming( .await?; if !response.status().is_success() { - return Err(format!("API request failed with status: {}", response.status()).into()); + return Err(ApiError::BadStatus(response.status().as_u16())); } let mut stream = response.bytes_stream(); let mut full_response = String::new(); let mut current_batch = String::new(); let mut tokens_since_last_send = 0; - const BATCH_SIZE: usize = 20; - const BATCH_TIMEOUT: Duration = Duration::from_millis(100); + let batch_timeout = Duration::from_millis(batch_timeout_ms); let mut last_send = tokio::time::Instant::now(); @@ -74,8 +87,8 @@ pub async fn send_chat_request_streaming( } // Send batch if conditions are met - let should_send = tokens_since_last_send >= BATCH_SIZE - || last_send.elapsed() >= BATCH_TIMEOUT + let should_send = tokens_since_last_send >= batch_size + || last_send.elapsed() >= batch_timeout || stream_response.done; if should_send { @@ -115,8 +128,195 @@ pub async fn send_chat_request_streaming( drop(token_sender); if full_response.is_empty() { - return Err("No response received from the model".into()); + return Err(ApiError::EmptyResponse); } - Ok((full_response, None)) + Ok(full_response) +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── fetch_models ───────────────────────────────────────────────────────── + + #[tokio::test] + async fn fetch_models_returns_model_list() { + let mut server = mockito::Server::new_async().await; + let _mock = server + .mock("GET", "/api/tags") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"models":[{"name":"llama3","modified_at":"2024-01-01T00:00:00Z","size":4000000}]}"#) + .create_async() + .await; + + let models = fetch_models(&server.url()).await.unwrap(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "llama3"); + } + + #[tokio::test] + async fn fetch_models_bad_status_returns_error() { + let mut server = mockito::Server::new_async().await; + let _mock = server + .mock("GET", "/api/tags") + .with_status(500) + .create_async() + .await; + + let err = fetch_models(&server.url()).await.unwrap_err(); + // reqwest treats non-success as an error only if we explicitly check; + // here fetch_models passes the status through response.json() which will + // fail because body is empty — so we get an Http or Parse error. + // The important thing: it is an error, not a success. + assert!(matches!(err, ApiError::Http(_) | ApiError::Parse(_) | ApiError::BadStatus(_))); + } + + #[tokio::test] + async fn fetch_models_bad_json_returns_parse_error() { + let mut server = mockito::Server::new_async().await; + let _mock = server + .mock("GET", "/api/tags") + .with_status(200) + .with_header("content-type", "application/json") + .with_body("not json") + .create_async() + .await; + + let err = fetch_models(&server.url()).await.unwrap_err(); + assert!(matches!(err, ApiError::Http(_) | ApiError::Parse(_))); + } + + // ── send_chat_request_streaming ────────────────────────────────────────── + + fn ndjson_lines(tokens: &[(&str, bool)]) -> String { + tokens + .iter() + .map(|(content, done)| { + format!( + r#"{{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{{"role":"assistant","content":"{}"}},"done":{}}}"#, + content, done + ) + }) + .collect::>() + .join("\n") + } + + async fn run_streaming(server_url: &str, batch_size: usize) -> (Result, Vec) { + let messages = vec![ChatMessage { + role: "user".to_string(), + content: "hi".to_string(), + }]; + let (tx, rx) = async_channel::unbounded(); + let result = send_chat_request_streaming( + server_url, "llama3", messages, tx, batch_size, 5000, + ) + .await; + + let mut batches = Vec::new(); + while let Ok(batch) = rx.try_recv() { + batches.push(batch); + } + (result, batches) + } + + #[tokio::test] + async fn streaming_single_token_returns_full_response() { + let mut server = mockito::Server::new_async().await; + let body = ndjson_lines(&[("Hello", true)]); + let _mock = server + .mock("POST", "/api/chat") + .with_status(200) + .with_body(body) + .create_async() + .await; + + let (result, _batches) = run_streaming(&server.url(), 100).await; + assert_eq!(result.unwrap(), "Hello"); + } + + #[tokio::test] + async fn streaming_multi_token_accumulates_full_response() { + let mut server = mockito::Server::new_async().await; + let body = ndjson_lines(&[("Hello", false), (" world", true)]); + let _mock = server + .mock("POST", "/api/chat") + .with_status(200) + .with_body(body) + .create_async() + .await; + + let (result, _) = run_streaming(&server.url(), 100).await; + assert_eq!(result.unwrap(), "Hello world"); + } + + #[tokio::test] + async fn streaming_batch_size_flushes_intermediate_batches() { + let mut server = mockito::Server::new_async().await; + // 3 tokens, batch_size=2 → first batch sent after 2 tokens, second after done + let body = ndjson_lines(&[("a", false), ("b", false), ("c", true)]); + let _mock = server + .mock("POST", "/api/chat") + .with_status(200) + .with_body(body) + .create_async() + .await; + + let (result, batches) = run_streaming(&server.url(), 2).await; + assert_eq!(result.unwrap(), "abc"); + // We should have received at least 2 channel messages (one mid-stream, one final) + assert!(batches.len() >= 2, "expected intermediate batches, got {:?}", batches); + assert_eq!(batches.join(""), "abc"); + } + + #[tokio::test] + async fn streaming_done_with_no_content_returns_empty_response_error() { + let mut server = mockito::Server::new_async().await; + // done:true but content is empty + let body = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":""},"done":true}"#; + let _mock = server + .mock("POST", "/api/chat") + .with_status(200) + .with_body(body) + .create_async() + .await; + + let (result, _) = run_streaming(&server.url(), 100).await; + assert!(matches!(result, Err(ApiError::EmptyResponse))); + } + + #[tokio::test] + async fn streaming_bad_status_returns_error() { + let mut server = mockito::Server::new_async().await; + let _mock = server + .mock("POST", "/api/chat") + .with_status(503) + .create_async() + .await; + + let (result, _) = run_streaming(&server.url(), 100).await; + assert!(matches!(result, Err(ApiError::BadStatus(503)))); + } + + #[tokio::test] + async fn streaming_malformed_json_line_is_skipped() { + let mut server = mockito::Server::new_async().await; + // A bad line in the middle should not abort processing + let body = format!( + "{}\nnot valid json\n{}", + r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hello"},"done":false}"#, + r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":" world"},"done":true}"#, + ); + let _mock = server + .mock("POST", "/api/chat") + .with_status(200) + .with_body(body) + .create_async() + .await; + + let (result, _) = run_streaming(&server.url(), 100).await; + // Should still accumulate valid tokens despite the bad line + assert_eq!(result.unwrap(), "Hello world"); + } } \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index fb93705..f58a1cf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -7,6 +7,7 @@ pub struct Config { pub ui: UiConfig, pub colors: ColorConfig, pub ollama: OllamaConfig, + pub streaming: StreamingConfig, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -39,6 +40,20 @@ pub struct ColorConfig { pub struct OllamaConfig { pub url: String, pub timeout_seconds: u64, + /// Maximum number of conversation turns sent to the model (most recent N messages). + /// Keeps context within the model's limit. Set higher for longer memory. + pub max_context_messages: usize, + /// Optional system prompt prepended to every conversation. + /// Leave empty ("") to disable. RAG can override this at runtime via AppState. + pub system_prompt: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamingConfig { + /// Number of tokens to accumulate before flushing to the UI. + pub batch_size: usize, + /// Maximum milliseconds to wait before flushing a partial batch. + pub batch_timeout_ms: u64, } impl Default for Config { @@ -47,6 +62,7 @@ impl Default for Config { ui: UiConfig::default(), colors: ColorConfig::default(), ollama: OllamaConfig::default(), + streaming: StreamingConfig::default(), } } } @@ -88,6 +104,17 @@ impl Default for OllamaConfig { Self { url: "http://localhost:11434".to_string(), timeout_seconds: 120, + max_context_messages: 20, + system_prompt: String::new(), + } + } +} + +impl Default for StreamingConfig { + fn default() -> Self { + Self { + batch_size: 20, + batch_timeout_ms: 100, } } } @@ -128,4 +155,62 @@ impl Config { Ok(config_dir.join("config.toml")) } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_config_has_sensible_values() { + let cfg = Config::default(); + + // Ollama endpoint + assert_eq!(cfg.ollama.url, "http://localhost:11434"); + assert!(cfg.ollama.timeout_seconds > 0); + + // Context window: 0 would mean every request sends zero messages — bad default. + assert!(cfg.ollama.max_context_messages > 0, + "max_context_messages must be > 0 or no history is ever sent"); + + // Streaming + assert!(cfg.streaming.batch_size > 0); + assert!(cfg.streaming.batch_timeout_ms > 0); + + // System prompt empty by default (RAG / user sets it when needed) + assert!(cfg.ollama.system_prompt.is_empty()); + } + + #[test] + fn config_roundtrips_through_toml() { + // Serialize to TOML and back — verifies the serde field names are stable. + // If you rename a field in the struct but forget to update the serde attribute, + // existing config files on disk will silently lose that value. This catches it. + let original = Config::default(); + let toml_str = toml::to_string_pretty(&original).unwrap(); + let parsed: Config = toml::from_str(&toml_str).unwrap(); + + assert_eq!(parsed.ollama.url, original.ollama.url); + assert_eq!(parsed.ollama.max_context_messages, original.ollama.max_context_messages); + assert_eq!(parsed.streaming.batch_size, original.streaming.batch_size); + assert_eq!(parsed.streaming.batch_timeout_ms, original.streaming.batch_timeout_ms); + assert_eq!(parsed.ui.window_font_size, original.ui.window_font_size); + } + + #[test] + fn config_roundtrip_preserves_custom_values() { + let mut cfg = Config::default(); + cfg.ollama.url = "http://my-server:11434".to_string(); + cfg.ollama.max_context_messages = 50; + cfg.streaming.batch_size = 5; + cfg.ollama.system_prompt = "You are a helpful assistant.".to_string(); + + let toml_str = toml::to_string_pretty(&cfg).unwrap(); + let parsed: Config = toml::from_str(&toml_str).unwrap(); + + assert_eq!(parsed.ollama.url, "http://my-server:11434"); + assert_eq!(parsed.ollama.max_context_messages, 50); + assert_eq!(parsed.streaming.batch_size, 5); + assert_eq!(parsed.ollama.system_prompt, "You are a helpful assistant."); + } } \ No newline at end of file diff --git a/src/markdown_renderer.rs b/src/markdown_renderer.rs index 5cab358..fd504ee 100644 --- a/src/markdown_renderer.rs +++ b/src/markdown_renderer.rs @@ -25,7 +25,6 @@ pub struct MarkdownRenderer { tags_setup: bool, // State for streaming think tag processing in_think_tag: bool, - think_buffer: String, } impl MarkdownRenderer { @@ -47,7 +46,6 @@ impl MarkdownRenderer { format_stack: Vec::new(), tags_setup: false, in_think_tag: false, - think_buffer: String::new(), } } @@ -183,58 +181,19 @@ impl MarkdownRenderer { } } - /// Process text for streaming, handling think tags in real-time + /// Process text for streaming, handling think tags in real-time. + /// + /// Delegates detection to [`parse_think_segments`] and handles GTK insertions per segment. fn process_streaming_text(&mut self, buffer: &TextBuffer, text: &str, iter: &mut TextIter) -> String { let mut result = String::new(); - let mut remaining = text; - - while !remaining.is_empty() { - if self.in_think_tag { - // We're currently inside a think tag, look for closing tag - if let Some(end_pos) = remaining.find("") { - // Found closing tag - stream the remaining think content - let final_think_content = &remaining[..end_pos]; - if !final_think_content.is_empty() { - buffer.insert_with_tags(iter, final_think_content, &[&self.think_tag]); - } - - // Close the think section - buffer.insert(iter, "\n\n"); - - // Reset think state - self.in_think_tag = false; - self.think_buffer.clear(); - - // Continue with text after closing tag - remaining = &remaining[end_pos + 8..]; // 8 = "".len() - } else { - // No closing tag yet, stream the think content as it arrives - if !remaining.is_empty() { - buffer.insert_with_tags(iter, remaining, &[&self.think_tag]); - } - break; // Wait for more streaming content - } - } else { - // Not in think tag, look for opening tag - if let Some(start_pos) = remaining.find("") { - // Add content before think tag to result for normal processing - result.push_str(&remaining[..start_pos]); - - // Start think mode and show the think indicator - self.in_think_tag = true; - self.think_buffer.clear(); - buffer.insert(iter, "\n💭 "); - - // Continue with content after opening tag - remaining = &remaining[start_pos + 7..]; // 7 = "".len() - } else { - // No think tag found, add all remaining text to result - result.push_str(remaining); - break; - } + for segment in parse_think_segments(text, &mut self.in_think_tag) { + match segment { + StreamSegment::Normal(s) => result.push_str(&s), + StreamSegment::ThinkStart => buffer.insert(iter, "\n💭 "), + StreamSegment::Think(s) => buffer.insert_with_tags(iter, &s, &[&self.think_tag]), + StreamSegment::ThinkEnd => buffer.insert(iter, "\n\n"), } } - result } @@ -386,6 +345,68 @@ impl MarkdownRenderer { } } +/// A parsed segment from streaming think-tag processing. +/// +/// Separating the pure detection logic (see [`parse_think_segments`]) from GTK mutations lets us +/// unit-test the state machine without a live display session. +#[derive(Debug, PartialEq)] +enum StreamSegment { + /// Plain text that passes through the markdown renderer. + Normal(String), + /// The `` opening boundary was seen; the GTK layer emits an indicator. + ThinkStart, + /// Content inside a think block, rendered with the think-tag style. + Think(String), + /// The `` closing boundary was seen; the GTK layer emits a separator. + ThinkEnd, +} + +/// Parse `text` into typed [`StreamSegment`]s, updating the in-flight `in_think` cursor. +/// +/// Streaming-safe: a think block may open in one call and close in a later call. +fn parse_think_segments(text: &str, in_think: &mut bool) -> Vec { + let mut segments = Vec::new(); + let mut remaining = text; + + while !remaining.is_empty() { + if *in_think { + match remaining.find("") { + Some(end_pos) => { + let content = &remaining[..end_pos]; + if !content.is_empty() { + segments.push(StreamSegment::Think(content.to_string())); + } + segments.push(StreamSegment::ThinkEnd); + *in_think = false; + remaining = &remaining[end_pos + 8..]; // skip "" + } + None => { + segments.push(StreamSegment::Think(remaining.to_string())); + break; + } + } + } else { + match remaining.find("") { + Some(start_pos) => { + let before = &remaining[..start_pos]; + if !before.is_empty() { + segments.push(StreamSegment::Normal(before.to_string())); + } + segments.push(StreamSegment::ThinkStart); + *in_think = true; + remaining = &remaining[start_pos + 7..]; // skip "" + } + None => { + segments.push(StreamSegment::Normal(remaining.to_string())); + break; + } + } + } + } + + segments +} + /// Helper function to parse color strings (hex format) into RGBA fn parse_color(color_str: &str) -> Result> { let color_str = color_str.trim_start_matches('#'); @@ -399,4 +420,131 @@ fn parse_color(color_str: &str) -> Resultthinking after", &mut in_think); + assert_eq!(segs, vec![ + StreamSegment::Normal("before ".into()), + StreamSegment::ThinkStart, + StreamSegment::Think("thinking".into()), + StreamSegment::ThinkEnd, + StreamSegment::Normal(" after".into()), + ]); + assert!(!in_think); + } + + #[test] + fn unclosed_think_tag_leaves_in_think_true() { + let mut in_think = false; + let segs = parse_think_segments("start partial", &mut in_think); + assert_eq!(segs, vec![ + StreamSegment::Normal("start ".into()), + StreamSegment::ThinkStart, + StreamSegment::Think("partial".into()), + ]); + assert!(in_think); + } + + #[test] + fn closing_tag_in_second_call_closes_correctly() { + let mut in_think = true; // simulate carrying over from previous call + let segs = parse_think_segments("rest normal", &mut in_think); + assert_eq!(segs, vec![ + StreamSegment::Think("rest".into()), + StreamSegment::ThinkEnd, + StreamSegment::Normal(" normal".into()), + ]); + assert!(!in_think); + } + + #[test] + fn continuation_while_in_think_produces_think_segment() { + let mut in_think = true; + let segs = parse_think_segments("more thinking...", &mut in_think); + assert_eq!(segs, vec![StreamSegment::Think("more thinking...".into())]); + assert!(in_think); + } + + #[test] + fn empty_think_block_produces_start_and_end_only() { + let mut in_think = false; + let segs = parse_think_segments("", &mut in_think); + assert_eq!(segs, vec![ + StreamSegment::ThinkStart, + StreamSegment::ThinkEnd, + ]); + assert!(!in_think); + } + + #[test] + fn think_block_at_very_start() { + let mut in_think = false; + let segs = parse_think_segments("reasoninganswer", &mut in_think); + assert_eq!(segs, vec![ + StreamSegment::ThinkStart, + StreamSegment::Think("reasoning".into()), + StreamSegment::ThinkEnd, + StreamSegment::Normal("answer".into()), + ]); + assert!(!in_think); + } } \ No newline at end of file diff --git a/src/state.rs b/src/state.rs index a9d2e23..c41e113 100644 --- a/src/state.rs +++ b/src/state.rs @@ -6,29 +6,23 @@ use crate::config::Config; pub type SharedState = Rc>; -#[derive(Debug)] +/// Application-level errors. Uses `thiserror` so each variant gets a clear, typed +/// message without boilerplate. Callers can match on the variant to handle errors +/// differently (e.g. show a dialog for Config vs. a status-bar message for Api). +#[derive(Debug, thiserror::Error)] pub enum AppError { + #[error("API error: {0}")] Api(String), + #[error("UI error: {0}")] Ui(String), + #[error("State error: {0}")] State(String), + #[error("Validation error: {0}")] Validation(String), + #[error("Config error: {0}")] Config(String), } -impl std::fmt::Display for AppError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - AppError::Api(msg) => write!(f, "API Error: {}", msg), - AppError::Ui(msg) => write!(f, "UI Error: {}", msg), - AppError::State(msg) => write!(f, "State Error: {}", msg), - AppError::Validation(msg) => write!(f, "Validation Error: {}", msg), - AppError::Config(msg) => write!(f, "Config Error: {}", msg), - } - } -} - -impl std::error::Error for AppError {} - pub type AppResult = Result; #[derive(Debug, Clone, Copy, PartialEq)] @@ -45,6 +39,9 @@ pub struct AppState { pub current_task: Option>, pub selected_model: Option, pub status_message: String, + /// System prompt prepended to every request. Initialized from config but can be + /// overridden at runtime (e.g. by a RAG pipeline to inject retrieved context). + pub system_prompt: Option, pub config: Config, } @@ -54,7 +51,13 @@ impl Default for AppState { eprintln!("Warning: Failed to load config, using defaults: {}", e); Config::default() }); - + + let system_prompt = if config.ollama.system_prompt.is_empty() { + None + } else { + Some(config.ollama.system_prompt.clone()) + }; + Self { conversation: Vec::new(), ollama_url: config.ollama.url.clone(), @@ -63,6 +66,7 @@ impl Default for AppState { current_task: None, selected_model: None, status_message: "Ready".to_string(), + system_prompt, config, } } @@ -103,5 +107,109 @@ impl AppState { self.set_generating(false); self.set_status("Generation stopped".to_string()); } - + +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_state() -> AppState { + AppState { + conversation: Vec::new(), + ollama_url: "http://localhost:11434".into(), + is_generating: false, + button_state: ButtonState::Send, + current_task: None, + selected_model: None, + status_message: "Ready".into(), + system_prompt: None, + config: Config::default(), + } + } + + #[test] + fn set_generating_true_sets_stop_state() { + let mut state = make_state(); + state.set_generating(true); + assert!(state.is_generating); + assert_eq!(state.button_state, ButtonState::Stop); + } + + #[test] + fn set_generating_false_sets_send_state() { + let mut state = make_state(); + state.is_generating = true; + state.button_state = ButtonState::Stop; + state.set_generating(false); + assert!(!state.is_generating); + assert_eq!(state.button_state, ButtonState::Send); + } + + #[test] + fn add_user_message_appends_with_correct_role() { + let mut state = make_state(); + state.add_user_message("hello".into()); + assert_eq!(state.conversation.len(), 1); + assert_eq!(state.conversation[0].role, "user"); + assert_eq!(state.conversation[0].content, "hello"); + } + + #[test] + fn add_assistant_message_appends_with_correct_role() { + let mut state = make_state(); + state.add_assistant_message("hi there".into()); + assert_eq!(state.conversation.len(), 1); + assert_eq!(state.conversation[0].role, "assistant"); + assert_eq!(state.conversation[0].content, "hi there"); + } + + #[test] + fn conversation_preserves_insertion_order() { + let mut state = make_state(); + state.add_user_message("first".into()); + state.add_assistant_message("second".into()); + state.add_user_message("third".into()); + assert_eq!(state.conversation.len(), 3); + assert_eq!(state.conversation[0].role, "user"); + assert_eq!(state.conversation[1].role, "assistant"); + assert_eq!(state.conversation[2].role, "user"); + } + + #[test] + fn set_status_updates_message() { + let mut state = make_state(); + state.set_status("Loading models...".into()); + assert_eq!(state.status_message, "Loading models..."); + } + + #[test] + fn abort_current_task_without_task_resets_state() { + let mut state = make_state(); + state.is_generating = true; + state.button_state = ButtonState::Stop; + state.abort_current_task(); + assert!(!state.is_generating); + assert_eq!(state.button_state, ButtonState::Send); + assert_eq!(state.status_message, "Generation stopped"); + assert!(state.current_task.is_none()); + } + + #[tokio::test] + async fn abort_current_task_aborts_running_task() { + let mut state = make_state(); + // Spawn a task that sleeps forever so we can verify it gets aborted + let handle = tokio::spawn(async { + tokio::time::sleep(tokio::time::Duration::from_secs(3600)).await; + }); + state.current_task = Some(handle); + state.is_generating = true; + state.button_state = ButtonState::Stop; + + state.abort_current_task(); + + assert!(state.current_task.is_none()); + assert!(!state.is_generating); + assert_eq!(state.status_message, "Generation stopped"); + } } \ No newline at end of file diff --git a/src/types.rs b/src/types.rs index 1a16ef1..0da4dbe 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +// --- Types --- + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { pub role: String, @@ -36,4 +38,47 @@ pub struct ModelInfo { #[derive(Debug, Serialize, Deserialize)] pub struct ModelsResponse { pub models: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + + // `cargo test` will run these. This is the standard Rust pattern: + // put tests in a `#[cfg(test)]` module so they're compiled only when testing. + + #[test] + fn chat_message_serializes_and_deserializes() { + // Verify that serde roundtrips correctly — if you ever rename a field + // or change its type, this test catches it before it reaches the API. + let msg = ChatMessage { + role: "user".to_string(), + content: "Hello, world!".to_string(), + }; + let json = serde_json::to_string(&msg).unwrap(); + let decoded: ChatMessage = serde_json::from_str(&json).unwrap(); + assert_eq!(decoded.role, msg.role); + assert_eq!(decoded.content, msg.content); + } + + #[test] + fn chat_request_includes_stream_flag() { + let req = ChatRequest { + model: "llama3".to_string(), + messages: vec![], + stream: true, + }; + let json = serde_json::to_string(&req).unwrap(); + // The Ollama API requires `"stream": true` in the body. + assert!(json.contains("\"stream\":true")); + } + + #[test] + fn stream_response_parses_done_flag() { + // Real payload shape from the Ollama streaming API. + let raw = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hi"},"done":true}"#; + let resp: StreamResponse = serde_json::from_str(raw).unwrap(); + assert!(resp.done); + assert_eq!(resp.message.content, "Hi"); + } } \ No newline at end of file diff --git a/src/ui/chat.rs b/src/ui/chat.rs index 84a2824..3f84176 100644 --- a/src/ui/chat.rs +++ b/src/ui/chat.rs @@ -76,16 +76,17 @@ impl ChatView { } // Add sender label with bold formatting - let sender_tag = gtk4::TextTag::new(Some("sender")); - sender_tag.set_weight(700); - sender_tag.set_property("pixels-below-lines", 4); - - // Add the sender tag to the buffer's tag table if it's not already there let tag_table = self.text_buffer.tag_table(); - if tag_table.lookup("sender").is_none() { - tag_table.add(&sender_tag); - } - + let sender_tag = if let Some(existing) = tag_table.lookup("sender") { + existing + } else { + let tag = gtk4::TextTag::new(Some("sender")); + tag.set_weight(700); + tag.set_property("pixels-below-lines", 4); + tag_table.add(&tag); + tag + }; + self.text_buffer.insert_with_tags(&mut end_iter, &format!("{}:\n", sender), &[&sender_tag]); end_iter = self.text_buffer.end_iter(); @@ -132,10 +133,7 @@ impl ChatView { // Delete existing content from mark to end self.text_buffer.delete(&mut start_iter, &mut end_iter); - - // Get a fresh iterator at the mark position after deletion - let _insert_iter = self.text_buffer.iter_at_mark(mark); - + // Render markdown directly to the main buffer // We use a separate method to avoid conflicts with the borrow checker self.render_markdown_at_mark(mark, accumulated_content, config); diff --git a/src/ui/handlers.rs b/src/ui/handlers.rs index 457caa6..8603f64 100644 --- a/src/ui/handlers.rs +++ b/src/ui/handlers.rs @@ -113,26 +113,15 @@ fn handle_stop_click( } fn set_generating_state( - shared_state: &SharedState, - controls: &ControlsArea, - button: >k4::Button, - generating: bool + shared_state: &SharedState, + controls: &ControlsArea, + button: >k4::Button, + generating: bool, ) { - { - let mut state = shared_state.borrow_mut(); - state.set_generating(generating); - state.set_status(if generating { - "Assistant is typing...".to_string() - } else { - "Ready".to_string() - }); - } + let status = if generating { "Assistant is typing..." } else { "Ready" }; + shared_state.borrow_mut().set_generating(generating); update_button_state(shared_state, button); - controls.set_status(if generating { - "Assistant is typing..." - } else { - "Ready" - }); + controls.set_status(status); } fn update_button_state(shared_state: &SharedState, button: >k4::Button) { @@ -177,12 +166,23 @@ fn start_streaming_task( model: String, ) { let (content_sender, content_receiver) = async_channel::bounded::(100); - let (result_sender, result_receiver) = async_channel::bounded(1); + let (result_sender, result_receiver) = async_channel::bounded::>(1); - // Extract data from shared state for API call - let (conversation, ollama_url) = { + // Extract data from shared state for API call. + // Only send the most recent `max_context_messages` turns to stay within the model's + // context window. Prepend the system prompt (if set) as the first message. + let (messages, ollama_url, batch_size, batch_timeout_ms) = { let state = shared_state.borrow(); - (state.conversation.clone(), state.ollama_url.clone()) + let max = state.config.ollama.max_context_messages; + let skip = state.conversation.len().saturating_sub(max); + let mut msgs: Vec<_> = state.conversation[skip..].to_vec(); + if let Some(ref prompt) = state.system_prompt { + msgs.insert(0, crate::types::ChatMessage { + role: "system".to_string(), + content: prompt.clone(), + }); + } + (msgs, state.ollama_url.clone(), state.config.streaming.batch_size, state.config.streaming.batch_timeout_ms) }; // Spawn API task @@ -190,8 +190,10 @@ fn start_streaming_task( let result = api::send_chat_request_streaming( &ollama_url, &model, - &std::sync::Arc::new(std::sync::Mutex::new(conversation)), + messages, content_sender, + batch_size, + batch_timeout_ms, ).await; let _ = result_sender.send(result).await; }); @@ -216,7 +218,7 @@ fn setup_streaming_handlers( controls: &ControlsArea, button: >k4::Button, content_receiver: async_channel::Receiver, - result_receiver: async_channel::Receiver), Box>>, + result_receiver: async_channel::Receiver>, ) { // Setup UI structure for streaming let mut end_iter = chat_view.buffer().end_iter(); @@ -241,7 +243,6 @@ fn setup_streaming_handlers( accumulated_content.push_str(&content_batch); let config = shared_state_streaming.borrow().config.clone(); chat_view_content.update_streaming_markdown(&response_mark_clone, &accumulated_content, &config); - chat_view_content.scroll_to_bottom(); } }); @@ -258,10 +259,10 @@ fn setup_streaming_handlers( Ok(response_text) => { // Apply final markdown formatting let config = shared_state_final.borrow().config.clone(); - chat_view_final.insert_formatted_at_mark(&response_mark, &response_text.0, &config); - + chat_view_final.insert_formatted_at_mark(&response_mark, &response_text, &config); + // Update conversation state - shared_state_final.borrow_mut().add_assistant_message(response_text.0); + shared_state_final.borrow_mut().add_assistant_message(response_text); set_generating_state(&shared_state_final, &controls_final, &button_final, false); } Err(e) => {