diff --git a/Cargo.toml b/Cargo.toml index 486db13..0153f3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,8 +15,3 @@ dirs = "6.0" futures-util = "0.3" async-channel = "2.3" pulldown-cmark = { version = "0.13.0", default-features = false, features = ["html"] } -thiserror = "2" - -[dev-dependencies] -mockito = "1" -tempfile = "3" diff --git a/README.md b/README.md index 4c4dd8e..f214846 100644 --- a/README.md +++ b/README.md @@ -61,10 +61,6 @@ stop_button = "#dc3545" [ollama] url = "http://localhost:11434" timeout_seconds = 120 - -[streaming] -batch_size = 20 -batch_timeout_ms = 100 ``` ## Building diff --git a/src/api.rs b/src/api.rs index 5188e91..630c2b4 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,29 +1,13 @@ +use std::sync::{Arc, Mutex}; use futures_util::StreamExt; use tokio::time::{timeout, Duration}; use crate::types::{ChatMessage, ChatRequest, ModelInfo, ModelsResponse, StreamResponse}; -/// Typed errors for the Ollama API layer. Using `thiserror` means callers can match -/// on exactly what went wrong instead of downcasting a `Box`. -#[derive(Debug, thiserror::Error)] -pub enum ApiError { - #[error("HTTP request failed: {0}")] - Http(#[from] reqwest::Error), - #[error("Request timed out")] - Timeout, - #[error("Server returned error status {0}")] - BadStatus(u16), - #[error("Failed to parse response: {0}")] - Parse(#[from] serde_json::Error), - #[error("Model returned empty response")] - EmptyResponse, -} - -pub async fn fetch_models(base_url: &str) -> Result, ApiError> { +pub async fn fetch_models(base_url: &str) -> Result, Box> { let url = format!("{}/api/tags", base_url); - - let response = timeout(Duration::from_secs(10), reqwest::get(&url)) - .await - .map_err(|_| ApiError::Timeout)??; + + // Add timeout to prevent hanging + let response = timeout(Duration::from_secs(10), reqwest::get(&url)).await??; let models_response: ModelsResponse = response.json().await?; Ok(models_response.models) } @@ -31,11 +15,13 @@ pub async fn fetch_models(base_url: &str) -> Result, ApiError> { pub async fn send_chat_request_streaming( base_url: &str, model: &str, - messages: Vec, + conversation: &Arc>>, token_sender: async_channel::Sender, - batch_size: usize, - batch_timeout_ms: u64, -) -> Result { +) -> Result<(String, Option), Box> { + let messages = { + let conversation = conversation.lock().unwrap(); + conversation.iter().cloned().collect::>() + }; let request = ChatRequest { model: model.to_string(), @@ -56,14 +42,15 @@ pub async fn send_chat_request_streaming( .await?; if !response.status().is_success() { - return Err(ApiError::BadStatus(response.status().as_u16())); + return Err(format!("API request failed with status: {}", response.status()).into()); } let mut stream = response.bytes_stream(); let mut full_response = String::new(); let mut current_batch = String::new(); let mut tokens_since_last_send = 0; - let batch_timeout = Duration::from_millis(batch_timeout_ms); + const BATCH_SIZE: usize = 20; + const BATCH_TIMEOUT: Duration = Duration::from_millis(100); let mut last_send = tokio::time::Instant::now(); @@ -87,8 +74,8 @@ pub async fn send_chat_request_streaming( } // Send batch if conditions are met - let should_send = tokens_since_last_send >= batch_size - || last_send.elapsed() >= batch_timeout + let should_send = tokens_since_last_send >= BATCH_SIZE + || last_send.elapsed() >= BATCH_TIMEOUT || stream_response.done; if should_send { @@ -128,195 +115,8 @@ pub async fn send_chat_request_streaming( drop(token_sender); if full_response.is_empty() { - return Err(ApiError::EmptyResponse); + return Err("No response received from the model".into()); } - Ok(full_response) -} - -#[cfg(test)] -mod tests { - use super::*; - - // ── fetch_models ───────────────────────────────────────────────────────── - - #[tokio::test] - async fn fetch_models_returns_model_list() { - let mut server = mockito::Server::new_async().await; - let _mock = server - .mock("GET", "/api/tags") - .with_status(200) - .with_header("content-type", "application/json") - .with_body(r#"{"models":[{"name":"llama3","modified_at":"2024-01-01T00:00:00Z","size":4000000}]}"#) - .create_async() - .await; - - let models = fetch_models(&server.url()).await.unwrap(); - assert_eq!(models.len(), 1); - assert_eq!(models[0].name, "llama3"); - } - - #[tokio::test] - async fn fetch_models_bad_status_returns_error() { - let mut server = mockito::Server::new_async().await; - let _mock = server - .mock("GET", "/api/tags") - .with_status(500) - .create_async() - .await; - - let err = fetch_models(&server.url()).await.unwrap_err(); - // reqwest treats non-success as an error only if we explicitly check; - // here fetch_models passes the status through response.json() which will - // fail because body is empty — so we get an Http or Parse error. - // The important thing: it is an error, not a success. - assert!(matches!(err, ApiError::Http(_) | ApiError::Parse(_) | ApiError::BadStatus(_))); - } - - #[tokio::test] - async fn fetch_models_bad_json_returns_parse_error() { - let mut server = mockito::Server::new_async().await; - let _mock = server - .mock("GET", "/api/tags") - .with_status(200) - .with_header("content-type", "application/json") - .with_body("not json") - .create_async() - .await; - - let err = fetch_models(&server.url()).await.unwrap_err(); - assert!(matches!(err, ApiError::Http(_) | ApiError::Parse(_))); - } - - // ── send_chat_request_streaming ────────────────────────────────────────── - - fn ndjson_lines(tokens: &[(&str, bool)]) -> String { - tokens - .iter() - .map(|(content, done)| { - format!( - r#"{{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{{"role":"assistant","content":"{}"}},"done":{}}}"#, - content, done - ) - }) - .collect::>() - .join("\n") - } - - async fn run_streaming(server_url: &str, batch_size: usize) -> (Result, Vec) { - let messages = vec![ChatMessage { - role: "user".to_string(), - content: "hi".to_string(), - }]; - let (tx, rx) = async_channel::unbounded(); - let result = send_chat_request_streaming( - server_url, "llama3", messages, tx, batch_size, 5000, - ) - .await; - - let mut batches = Vec::new(); - while let Ok(batch) = rx.try_recv() { - batches.push(batch); - } - (result, batches) - } - - #[tokio::test] - async fn streaming_single_token_returns_full_response() { - let mut server = mockito::Server::new_async().await; - let body = ndjson_lines(&[("Hello", true)]); - let _mock = server - .mock("POST", "/api/chat") - .with_status(200) - .with_body(body) - .create_async() - .await; - - let (result, _batches) = run_streaming(&server.url(), 100).await; - assert_eq!(result.unwrap(), "Hello"); - } - - #[tokio::test] - async fn streaming_multi_token_accumulates_full_response() { - let mut server = mockito::Server::new_async().await; - let body = ndjson_lines(&[("Hello", false), (" world", true)]); - let _mock = server - .mock("POST", "/api/chat") - .with_status(200) - .with_body(body) - .create_async() - .await; - - let (result, _) = run_streaming(&server.url(), 100).await; - assert_eq!(result.unwrap(), "Hello world"); - } - - #[tokio::test] - async fn streaming_batch_size_flushes_intermediate_batches() { - let mut server = mockito::Server::new_async().await; - // 3 tokens, batch_size=2 → first batch sent after 2 tokens, second after done - let body = ndjson_lines(&[("a", false), ("b", false), ("c", true)]); - let _mock = server - .mock("POST", "/api/chat") - .with_status(200) - .with_body(body) - .create_async() - .await; - - let (result, batches) = run_streaming(&server.url(), 2).await; - assert_eq!(result.unwrap(), "abc"); - // We should have received at least 2 channel messages (one mid-stream, one final) - assert!(batches.len() >= 2, "expected intermediate batches, got {:?}", batches); - assert_eq!(batches.join(""), "abc"); - } - - #[tokio::test] - async fn streaming_done_with_no_content_returns_empty_response_error() { - let mut server = mockito::Server::new_async().await; - // done:true but content is empty - let body = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":""},"done":true}"#; - let _mock = server - .mock("POST", "/api/chat") - .with_status(200) - .with_body(body) - .create_async() - .await; - - let (result, _) = run_streaming(&server.url(), 100).await; - assert!(matches!(result, Err(ApiError::EmptyResponse))); - } - - #[tokio::test] - async fn streaming_bad_status_returns_error() { - let mut server = mockito::Server::new_async().await; - let _mock = server - .mock("POST", "/api/chat") - .with_status(503) - .create_async() - .await; - - let (result, _) = run_streaming(&server.url(), 100).await; - assert!(matches!(result, Err(ApiError::BadStatus(503)))); - } - - #[tokio::test] - async fn streaming_malformed_json_line_is_skipped() { - let mut server = mockito::Server::new_async().await; - // A bad line in the middle should not abort processing - let body = format!( - "{}\nnot valid json\n{}", - r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hello"},"done":false}"#, - r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":" world"},"done":true}"#, - ); - let _mock = server - .mock("POST", "/api/chat") - .with_status(200) - .with_body(body) - .create_async() - .await; - - let (result, _) = run_streaming(&server.url(), 100).await; - // Should still accumulate valid tokens despite the bad line - assert_eq!(result.unwrap(), "Hello world"); - } + Ok((full_response, None)) } \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index f58a1cf..fb93705 100644 --- a/src/config.rs +++ b/src/config.rs @@ -7,7 +7,6 @@ pub struct Config { pub ui: UiConfig, pub colors: ColorConfig, pub ollama: OllamaConfig, - pub streaming: StreamingConfig, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -40,20 +39,6 @@ pub struct ColorConfig { pub struct OllamaConfig { pub url: String, pub timeout_seconds: u64, - /// Maximum number of conversation turns sent to the model (most recent N messages). - /// Keeps context within the model's limit. Set higher for longer memory. - pub max_context_messages: usize, - /// Optional system prompt prepended to every conversation. - /// Leave empty ("") to disable. RAG can override this at runtime via AppState. - pub system_prompt: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StreamingConfig { - /// Number of tokens to accumulate before flushing to the UI. - pub batch_size: usize, - /// Maximum milliseconds to wait before flushing a partial batch. - pub batch_timeout_ms: u64, } impl Default for Config { @@ -62,7 +47,6 @@ impl Default for Config { ui: UiConfig::default(), colors: ColorConfig::default(), ollama: OllamaConfig::default(), - streaming: StreamingConfig::default(), } } } @@ -104,17 +88,6 @@ impl Default for OllamaConfig { Self { url: "http://localhost:11434".to_string(), timeout_seconds: 120, - max_context_messages: 20, - system_prompt: String::new(), - } - } -} - -impl Default for StreamingConfig { - fn default() -> Self { - Self { - batch_size: 20, - batch_timeout_ms: 100, } } } @@ -155,62 +128,4 @@ impl Config { Ok(config_dir.join("config.toml")) } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn default_config_has_sensible_values() { - let cfg = Config::default(); - - // Ollama endpoint - assert_eq!(cfg.ollama.url, "http://localhost:11434"); - assert!(cfg.ollama.timeout_seconds > 0); - - // Context window: 0 would mean every request sends zero messages — bad default. - assert!(cfg.ollama.max_context_messages > 0, - "max_context_messages must be > 0 or no history is ever sent"); - - // Streaming - assert!(cfg.streaming.batch_size > 0); - assert!(cfg.streaming.batch_timeout_ms > 0); - - // System prompt empty by default (RAG / user sets it when needed) - assert!(cfg.ollama.system_prompt.is_empty()); - } - - #[test] - fn config_roundtrips_through_toml() { - // Serialize to TOML and back — verifies the serde field names are stable. - // If you rename a field in the struct but forget to update the serde attribute, - // existing config files on disk will silently lose that value. This catches it. - let original = Config::default(); - let toml_str = toml::to_string_pretty(&original).unwrap(); - let parsed: Config = toml::from_str(&toml_str).unwrap(); - - assert_eq!(parsed.ollama.url, original.ollama.url); - assert_eq!(parsed.ollama.max_context_messages, original.ollama.max_context_messages); - assert_eq!(parsed.streaming.batch_size, original.streaming.batch_size); - assert_eq!(parsed.streaming.batch_timeout_ms, original.streaming.batch_timeout_ms); - assert_eq!(parsed.ui.window_font_size, original.ui.window_font_size); - } - - #[test] - fn config_roundtrip_preserves_custom_values() { - let mut cfg = Config::default(); - cfg.ollama.url = "http://my-server:11434".to_string(); - cfg.ollama.max_context_messages = 50; - cfg.streaming.batch_size = 5; - cfg.ollama.system_prompt = "You are a helpful assistant.".to_string(); - - let toml_str = toml::to_string_pretty(&cfg).unwrap(); - let parsed: Config = toml::from_str(&toml_str).unwrap(); - - assert_eq!(parsed.ollama.url, "http://my-server:11434"); - assert_eq!(parsed.ollama.max_context_messages, 50); - assert_eq!(parsed.streaming.batch_size, 5); - assert_eq!(parsed.ollama.system_prompt, "You are a helpful assistant."); - } } \ No newline at end of file diff --git a/src/markdown_renderer.rs b/src/markdown_renderer.rs index fd504ee..5cab358 100644 --- a/src/markdown_renderer.rs +++ b/src/markdown_renderer.rs @@ -25,6 +25,7 @@ pub struct MarkdownRenderer { tags_setup: bool, // State for streaming think tag processing in_think_tag: bool, + think_buffer: String, } impl MarkdownRenderer { @@ -46,6 +47,7 @@ impl MarkdownRenderer { format_stack: Vec::new(), tags_setup: false, in_think_tag: false, + think_buffer: String::new(), } } @@ -181,19 +183,58 @@ impl MarkdownRenderer { } } - /// Process text for streaming, handling think tags in real-time. - /// - /// Delegates detection to [`parse_think_segments`] and handles GTK insertions per segment. + /// Process text for streaming, handling think tags in real-time fn process_streaming_text(&mut self, buffer: &TextBuffer, text: &str, iter: &mut TextIter) -> String { let mut result = String::new(); - for segment in parse_think_segments(text, &mut self.in_think_tag) { - match segment { - StreamSegment::Normal(s) => result.push_str(&s), - StreamSegment::ThinkStart => buffer.insert(iter, "\n💭 "), - StreamSegment::Think(s) => buffer.insert_with_tags(iter, &s, &[&self.think_tag]), - StreamSegment::ThinkEnd => buffer.insert(iter, "\n\n"), + let mut remaining = text; + + while !remaining.is_empty() { + if self.in_think_tag { + // We're currently inside a think tag, look for closing tag + if let Some(end_pos) = remaining.find("") { + // Found closing tag - stream the remaining think content + let final_think_content = &remaining[..end_pos]; + if !final_think_content.is_empty() { + buffer.insert_with_tags(iter, final_think_content, &[&self.think_tag]); + } + + // Close the think section + buffer.insert(iter, "\n\n"); + + // Reset think state + self.in_think_tag = false; + self.think_buffer.clear(); + + // Continue with text after closing tag + remaining = &remaining[end_pos + 8..]; // 8 = "".len() + } else { + // No closing tag yet, stream the think content as it arrives + if !remaining.is_empty() { + buffer.insert_with_tags(iter, remaining, &[&self.think_tag]); + } + break; // Wait for more streaming content + } + } else { + // Not in think tag, look for opening tag + if let Some(start_pos) = remaining.find("") { + // Add content before think tag to result for normal processing + result.push_str(&remaining[..start_pos]); + + // Start think mode and show the think indicator + self.in_think_tag = true; + self.think_buffer.clear(); + buffer.insert(iter, "\n💭 "); + + // Continue with content after opening tag + remaining = &remaining[start_pos + 7..]; // 7 = "".len() + } else { + // No think tag found, add all remaining text to result + result.push_str(remaining); + break; + } } } + result } @@ -345,68 +386,6 @@ impl MarkdownRenderer { } } -/// A parsed segment from streaming think-tag processing. -/// -/// Separating the pure detection logic (see [`parse_think_segments`]) from GTK mutations lets us -/// unit-test the state machine without a live display session. -#[derive(Debug, PartialEq)] -enum StreamSegment { - /// Plain text that passes through the markdown renderer. - Normal(String), - /// The `` opening boundary was seen; the GTK layer emits an indicator. - ThinkStart, - /// Content inside a think block, rendered with the think-tag style. - Think(String), - /// The `` closing boundary was seen; the GTK layer emits a separator. - ThinkEnd, -} - -/// Parse `text` into typed [`StreamSegment`]s, updating the in-flight `in_think` cursor. -/// -/// Streaming-safe: a think block may open in one call and close in a later call. -fn parse_think_segments(text: &str, in_think: &mut bool) -> Vec { - let mut segments = Vec::new(); - let mut remaining = text; - - while !remaining.is_empty() { - if *in_think { - match remaining.find("") { - Some(end_pos) => { - let content = &remaining[..end_pos]; - if !content.is_empty() { - segments.push(StreamSegment::Think(content.to_string())); - } - segments.push(StreamSegment::ThinkEnd); - *in_think = false; - remaining = &remaining[end_pos + 8..]; // skip "" - } - None => { - segments.push(StreamSegment::Think(remaining.to_string())); - break; - } - } - } else { - match remaining.find("") { - Some(start_pos) => { - let before = &remaining[..start_pos]; - if !before.is_empty() { - segments.push(StreamSegment::Normal(before.to_string())); - } - segments.push(StreamSegment::ThinkStart); - *in_think = true; - remaining = &remaining[start_pos + 7..]; // skip "" - } - None => { - segments.push(StreamSegment::Normal(remaining.to_string())); - break; - } - } - } - } - - segments -} - /// Helper function to parse color strings (hex format) into RGBA fn parse_color(color_str: &str) -> Result> { let color_str = color_str.trim_start_matches('#'); @@ -420,131 +399,4 @@ fn parse_color(color_str: &str) -> Resultthinking after", &mut in_think); - assert_eq!(segs, vec![ - StreamSegment::Normal("before ".into()), - StreamSegment::ThinkStart, - StreamSegment::Think("thinking".into()), - StreamSegment::ThinkEnd, - StreamSegment::Normal(" after".into()), - ]); - assert!(!in_think); - } - - #[test] - fn unclosed_think_tag_leaves_in_think_true() { - let mut in_think = false; - let segs = parse_think_segments("start partial", &mut in_think); - assert_eq!(segs, vec![ - StreamSegment::Normal("start ".into()), - StreamSegment::ThinkStart, - StreamSegment::Think("partial".into()), - ]); - assert!(in_think); - } - - #[test] - fn closing_tag_in_second_call_closes_correctly() { - let mut in_think = true; // simulate carrying over from previous call - let segs = parse_think_segments("rest normal", &mut in_think); - assert_eq!(segs, vec![ - StreamSegment::Think("rest".into()), - StreamSegment::ThinkEnd, - StreamSegment::Normal(" normal".into()), - ]); - assert!(!in_think); - } - - #[test] - fn continuation_while_in_think_produces_think_segment() { - let mut in_think = true; - let segs = parse_think_segments("more thinking...", &mut in_think); - assert_eq!(segs, vec![StreamSegment::Think("more thinking...".into())]); - assert!(in_think); - } - - #[test] - fn empty_think_block_produces_start_and_end_only() { - let mut in_think = false; - let segs = parse_think_segments("", &mut in_think); - assert_eq!(segs, vec![ - StreamSegment::ThinkStart, - StreamSegment::ThinkEnd, - ]); - assert!(!in_think); - } - - #[test] - fn think_block_at_very_start() { - let mut in_think = false; - let segs = parse_think_segments("reasoninganswer", &mut in_think); - assert_eq!(segs, vec![ - StreamSegment::ThinkStart, - StreamSegment::Think("reasoning".into()), - StreamSegment::ThinkEnd, - StreamSegment::Normal("answer".into()), - ]); - assert!(!in_think); - } } \ No newline at end of file diff --git a/src/state.rs b/src/state.rs index c41e113..a9d2e23 100644 --- a/src/state.rs +++ b/src/state.rs @@ -6,23 +6,29 @@ use crate::config::Config; pub type SharedState = Rc>; -/// Application-level errors. Uses `thiserror` so each variant gets a clear, typed -/// message without boilerplate. Callers can match on the variant to handle errors -/// differently (e.g. show a dialog for Config vs. a status-bar message for Api). -#[derive(Debug, thiserror::Error)] +#[derive(Debug)] pub enum AppError { - #[error("API error: {0}")] Api(String), - #[error("UI error: {0}")] Ui(String), - #[error("State error: {0}")] State(String), - #[error("Validation error: {0}")] Validation(String), - #[error("Config error: {0}")] Config(String), } +impl std::fmt::Display for AppError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AppError::Api(msg) => write!(f, "API Error: {}", msg), + AppError::Ui(msg) => write!(f, "UI Error: {}", msg), + AppError::State(msg) => write!(f, "State Error: {}", msg), + AppError::Validation(msg) => write!(f, "Validation Error: {}", msg), + AppError::Config(msg) => write!(f, "Config Error: {}", msg), + } + } +} + +impl std::error::Error for AppError {} + pub type AppResult = Result; #[derive(Debug, Clone, Copy, PartialEq)] @@ -39,9 +45,6 @@ pub struct AppState { pub current_task: Option>, pub selected_model: Option, pub status_message: String, - /// System prompt prepended to every request. Initialized from config but can be - /// overridden at runtime (e.g. by a RAG pipeline to inject retrieved context). - pub system_prompt: Option, pub config: Config, } @@ -51,13 +54,7 @@ impl Default for AppState { eprintln!("Warning: Failed to load config, using defaults: {}", e); Config::default() }); - - let system_prompt = if config.ollama.system_prompt.is_empty() { - None - } else { - Some(config.ollama.system_prompt.clone()) - }; - + Self { conversation: Vec::new(), ollama_url: config.ollama.url.clone(), @@ -66,7 +63,6 @@ impl Default for AppState { current_task: None, selected_model: None, status_message: "Ready".to_string(), - system_prompt, config, } } @@ -107,109 +103,5 @@ impl AppState { self.set_generating(false); self.set_status("Generation stopped".to_string()); } - -} - -#[cfg(test)] -mod tests { - use super::*; - - fn make_state() -> AppState { - AppState { - conversation: Vec::new(), - ollama_url: "http://localhost:11434".into(), - is_generating: false, - button_state: ButtonState::Send, - current_task: None, - selected_model: None, - status_message: "Ready".into(), - system_prompt: None, - config: Config::default(), - } - } - - #[test] - fn set_generating_true_sets_stop_state() { - let mut state = make_state(); - state.set_generating(true); - assert!(state.is_generating); - assert_eq!(state.button_state, ButtonState::Stop); - } - - #[test] - fn set_generating_false_sets_send_state() { - let mut state = make_state(); - state.is_generating = true; - state.button_state = ButtonState::Stop; - state.set_generating(false); - assert!(!state.is_generating); - assert_eq!(state.button_state, ButtonState::Send); - } - - #[test] - fn add_user_message_appends_with_correct_role() { - let mut state = make_state(); - state.add_user_message("hello".into()); - assert_eq!(state.conversation.len(), 1); - assert_eq!(state.conversation[0].role, "user"); - assert_eq!(state.conversation[0].content, "hello"); - } - - #[test] - fn add_assistant_message_appends_with_correct_role() { - let mut state = make_state(); - state.add_assistant_message("hi there".into()); - assert_eq!(state.conversation.len(), 1); - assert_eq!(state.conversation[0].role, "assistant"); - assert_eq!(state.conversation[0].content, "hi there"); - } - - #[test] - fn conversation_preserves_insertion_order() { - let mut state = make_state(); - state.add_user_message("first".into()); - state.add_assistant_message("second".into()); - state.add_user_message("third".into()); - assert_eq!(state.conversation.len(), 3); - assert_eq!(state.conversation[0].role, "user"); - assert_eq!(state.conversation[1].role, "assistant"); - assert_eq!(state.conversation[2].role, "user"); - } - - #[test] - fn set_status_updates_message() { - let mut state = make_state(); - state.set_status("Loading models...".into()); - assert_eq!(state.status_message, "Loading models..."); - } - - #[test] - fn abort_current_task_without_task_resets_state() { - let mut state = make_state(); - state.is_generating = true; - state.button_state = ButtonState::Stop; - state.abort_current_task(); - assert!(!state.is_generating); - assert_eq!(state.button_state, ButtonState::Send); - assert_eq!(state.status_message, "Generation stopped"); - assert!(state.current_task.is_none()); - } - - #[tokio::test] - async fn abort_current_task_aborts_running_task() { - let mut state = make_state(); - // Spawn a task that sleeps forever so we can verify it gets aborted - let handle = tokio::spawn(async { - tokio::time::sleep(tokio::time::Duration::from_secs(3600)).await; - }); - state.current_task = Some(handle); - state.is_generating = true; - state.button_state = ButtonState::Stop; - - state.abort_current_task(); - - assert!(state.current_task.is_none()); - assert!(!state.is_generating); - assert_eq!(state.status_message, "Generation stopped"); - } + } \ No newline at end of file diff --git a/src/types.rs b/src/types.rs index 0da4dbe..1a16ef1 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,7 +1,5 @@ use serde::{Deserialize, Serialize}; -// --- Types --- - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { pub role: String, @@ -38,47 +36,4 @@ pub struct ModelInfo { #[derive(Debug, Serialize, Deserialize)] pub struct ModelsResponse { pub models: Vec, -} - -#[cfg(test)] -mod tests { - use super::*; - - // `cargo test` will run these. This is the standard Rust pattern: - // put tests in a `#[cfg(test)]` module so they're compiled only when testing. - - #[test] - fn chat_message_serializes_and_deserializes() { - // Verify that serde roundtrips correctly — if you ever rename a field - // or change its type, this test catches it before it reaches the API. - let msg = ChatMessage { - role: "user".to_string(), - content: "Hello, world!".to_string(), - }; - let json = serde_json::to_string(&msg).unwrap(); - let decoded: ChatMessage = serde_json::from_str(&json).unwrap(); - assert_eq!(decoded.role, msg.role); - assert_eq!(decoded.content, msg.content); - } - - #[test] - fn chat_request_includes_stream_flag() { - let req = ChatRequest { - model: "llama3".to_string(), - messages: vec![], - stream: true, - }; - let json = serde_json::to_string(&req).unwrap(); - // The Ollama API requires `"stream": true` in the body. - assert!(json.contains("\"stream\":true")); - } - - #[test] - fn stream_response_parses_done_flag() { - // Real payload shape from the Ollama streaming API. - let raw = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hi"},"done":true}"#; - let resp: StreamResponse = serde_json::from_str(raw).unwrap(); - assert!(resp.done); - assert_eq!(resp.message.content, "Hi"); - } } \ No newline at end of file diff --git a/src/ui/chat.rs b/src/ui/chat.rs index 3f84176..84a2824 100644 --- a/src/ui/chat.rs +++ b/src/ui/chat.rs @@ -76,17 +76,16 @@ impl ChatView { } // Add sender label with bold formatting + let sender_tag = gtk4::TextTag::new(Some("sender")); + sender_tag.set_weight(700); + sender_tag.set_property("pixels-below-lines", 4); + + // Add the sender tag to the buffer's tag table if it's not already there let tag_table = self.text_buffer.tag_table(); - let sender_tag = if let Some(existing) = tag_table.lookup("sender") { - existing - } else { - let tag = gtk4::TextTag::new(Some("sender")); - tag.set_weight(700); - tag.set_property("pixels-below-lines", 4); - tag_table.add(&tag); - tag - }; - + if tag_table.lookup("sender").is_none() { + tag_table.add(&sender_tag); + } + self.text_buffer.insert_with_tags(&mut end_iter, &format!("{}:\n", sender), &[&sender_tag]); end_iter = self.text_buffer.end_iter(); @@ -133,7 +132,10 @@ impl ChatView { // Delete existing content from mark to end self.text_buffer.delete(&mut start_iter, &mut end_iter); - + + // Get a fresh iterator at the mark position after deletion + let _insert_iter = self.text_buffer.iter_at_mark(mark); + // Render markdown directly to the main buffer // We use a separate method to avoid conflicts with the borrow checker self.render_markdown_at_mark(mark, accumulated_content, config); diff --git a/src/ui/handlers.rs b/src/ui/handlers.rs index 8603f64..457caa6 100644 --- a/src/ui/handlers.rs +++ b/src/ui/handlers.rs @@ -113,15 +113,26 @@ fn handle_stop_click( } fn set_generating_state( - shared_state: &SharedState, - controls: &ControlsArea, - button: >k4::Button, - generating: bool, + shared_state: &SharedState, + controls: &ControlsArea, + button: >k4::Button, + generating: bool ) { - let status = if generating { "Assistant is typing..." } else { "Ready" }; - shared_state.borrow_mut().set_generating(generating); + { + let mut state = shared_state.borrow_mut(); + state.set_generating(generating); + state.set_status(if generating { + "Assistant is typing...".to_string() + } else { + "Ready".to_string() + }); + } update_button_state(shared_state, button); - controls.set_status(status); + controls.set_status(if generating { + "Assistant is typing..." + } else { + "Ready" + }); } fn update_button_state(shared_state: &SharedState, button: >k4::Button) { @@ -166,23 +177,12 @@ fn start_streaming_task( model: String, ) { let (content_sender, content_receiver) = async_channel::bounded::(100); - let (result_sender, result_receiver) = async_channel::bounded::>(1); + let (result_sender, result_receiver) = async_channel::bounded(1); - // Extract data from shared state for API call. - // Only send the most recent `max_context_messages` turns to stay within the model's - // context window. Prepend the system prompt (if set) as the first message. - let (messages, ollama_url, batch_size, batch_timeout_ms) = { + // Extract data from shared state for API call + let (conversation, ollama_url) = { let state = shared_state.borrow(); - let max = state.config.ollama.max_context_messages; - let skip = state.conversation.len().saturating_sub(max); - let mut msgs: Vec<_> = state.conversation[skip..].to_vec(); - if let Some(ref prompt) = state.system_prompt { - msgs.insert(0, crate::types::ChatMessage { - role: "system".to_string(), - content: prompt.clone(), - }); - } - (msgs, state.ollama_url.clone(), state.config.streaming.batch_size, state.config.streaming.batch_timeout_ms) + (state.conversation.clone(), state.ollama_url.clone()) }; // Spawn API task @@ -190,10 +190,8 @@ fn start_streaming_task( let result = api::send_chat_request_streaming( &ollama_url, &model, - messages, + &std::sync::Arc::new(std::sync::Mutex::new(conversation)), content_sender, - batch_size, - batch_timeout_ms, ).await; let _ = result_sender.send(result).await; }); @@ -218,7 +216,7 @@ fn setup_streaming_handlers( controls: &ControlsArea, button: >k4::Button, content_receiver: async_channel::Receiver, - result_receiver: async_channel::Receiver>, + result_receiver: async_channel::Receiver), Box>>, ) { // Setup UI structure for streaming let mut end_iter = chat_view.buffer().end_iter(); @@ -243,6 +241,7 @@ fn setup_streaming_handlers( accumulated_content.push_str(&content_batch); let config = shared_state_streaming.borrow().config.clone(); chat_view_content.update_streaming_markdown(&response_mark_clone, &accumulated_content, &config); + chat_view_content.scroll_to_bottom(); } }); @@ -259,10 +258,10 @@ fn setup_streaming_handlers( Ok(response_text) => { // Apply final markdown formatting let config = shared_state_final.borrow().config.clone(); - chat_view_final.insert_formatted_at_mark(&response_mark, &response_text, &config); - + chat_view_final.insert_formatted_at_mark(&response_mark, &response_text.0, &config); + // Update conversation state - shared_state_final.borrow_mut().add_assistant_message(response_text); + shared_state_final.borrow_mut().add_assistant_message(response_text.0); set_generating_state(&shared_state_final, &controls_final, &button_final, false); } Err(e) => {