diff --git a/Cargo.toml b/Cargo.toml index 74515aa..486db13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,7 @@ futures-util = "0.3" async-channel = "2.3" pulldown-cmark = { version = "0.13.0", default-features = false, features = ["html"] } thiserror = "2" + +[dev-dependencies] +mockito = "1" +tempfile = "3" diff --git a/src/api.rs b/src/api.rs index 4e40901..5188e91 100644 --- a/src/api.rs +++ b/src/api.rs @@ -132,4 +132,191 @@ pub async fn send_chat_request_streaming( } Ok(full_response) +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── fetch_models ───────────────────────────────────────────────────────── + + #[tokio::test] + async fn fetch_models_returns_model_list() { + let mut server = mockito::Server::new_async().await; + let _mock = server + .mock("GET", "/api/tags") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"models":[{"name":"llama3","modified_at":"2024-01-01T00:00:00Z","size":4000000}]}"#) + .create_async() + .await; + + let models = fetch_models(&server.url()).await.unwrap(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "llama3"); + } + + #[tokio::test] + async fn fetch_models_bad_status_returns_error() { + let mut server = mockito::Server::new_async().await; + let _mock = server + .mock("GET", "/api/tags") + .with_status(500) + .create_async() + .await; + + let err = fetch_models(&server.url()).await.unwrap_err(); + // reqwest treats non-success as an error only if we explicitly check; + // here fetch_models passes the status through response.json() which will + // fail because body is empty — so we get an Http or Parse error. + // The important thing: it is an error, not a success. 
+ assert!(matches!(err, ApiError::Http(_) | ApiError::Parse(_) | ApiError::BadStatus(_))); + } + + #[tokio::test] + async fn fetch_models_bad_json_returns_parse_error() { + let mut server = mockito::Server::new_async().await; + let _mock = server + .mock("GET", "/api/tags") + .with_status(200) + .with_header("content-type", "application/json") + .with_body("not json") + .create_async() + .await; + + let err = fetch_models(&server.url()).await.unwrap_err(); + assert!(matches!(err, ApiError::Http(_) | ApiError::Parse(_))); + } + + // ── send_chat_request_streaming ────────────────────────────────────────── + + fn ndjson_lines(tokens: &[(&str, bool)]) -> String { + tokens + .iter() + .map(|(content, done)| { + format!( + r#"{{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{{"role":"assistant","content":"{}"}},"done":{}}}"#, + content, done + ) + }) + .collect::<Vec<String>>() + .join("\n") + } + + async fn run_streaming(server_url: &str, batch_size: usize) -> (Result<String, ApiError>, Vec<String>) { + let messages = vec![ChatMessage { + role: "user".to_string(), + content: "hi".to_string(), + }]; + let (tx, rx) = async_channel::unbounded(); + let result = send_chat_request_streaming( + server_url, "llama3", messages, tx, batch_size, 5000, + ) + .await; + + let mut batches = Vec::new(); + while let Ok(batch) = rx.try_recv() { + batches.push(batch); + } + (result, batches) + } + + #[tokio::test] + async fn streaming_single_token_returns_full_response() { + let mut server = mockito::Server::new_async().await; + let body = ndjson_lines(&[("Hello", true)]); + let _mock = server + .mock("POST", "/api/chat") + .with_status(200) + .with_body(body) + .create_async() + .await; + + let (result, _batches) = run_streaming(&server.url(), 100).await; + assert_eq!(result.unwrap(), "Hello"); + } + + #[tokio::test] + async fn streaming_multi_token_accumulates_full_response() { + let mut server = mockito::Server::new_async().await; + let body = ndjson_lines(&[("Hello", false), (" world", true)]); + let
_mock = server + .mock("POST", "/api/chat") + .with_status(200) + .with_body(body) + .create_async() + .await; + + let (result, _) = run_streaming(&server.url(), 100).await; + assert_eq!(result.unwrap(), "Hello world"); + } + + #[tokio::test] + async fn streaming_batch_size_flushes_intermediate_batches() { + let mut server = mockito::Server::new_async().await; + // 3 tokens, batch_size=2 → first batch sent after 2 tokens, second after done + let body = ndjson_lines(&[("a", false), ("b", false), ("c", true)]); + let _mock = server + .mock("POST", "/api/chat") + .with_status(200) + .with_body(body) + .create_async() + .await; + + let (result, batches) = run_streaming(&server.url(), 2).await; + assert_eq!(result.unwrap(), "abc"); + // We should have received at least 2 channel messages (one mid-stream, one final) + assert!(batches.len() >= 2, "expected intermediate batches, got {:?}", batches); + assert_eq!(batches.join(""), "abc"); + } + + #[tokio::test] + async fn streaming_done_with_no_content_returns_empty_response_error() { + let mut server = mockito::Server::new_async().await; + // done:true but content is empty + let body = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":""},"done":true}"#; + let _mock = server + .mock("POST", "/api/chat") + .with_status(200) + .with_body(body) + .create_async() + .await; + + let (result, _) = run_streaming(&server.url(), 100).await; + assert!(matches!(result, Err(ApiError::EmptyResponse))); + } + + #[tokio::test] + async fn streaming_bad_status_returns_error() { + let mut server = mockito::Server::new_async().await; + let _mock = server + .mock("POST", "/api/chat") + .with_status(503) + .create_async() + .await; + + let (result, _) = run_streaming(&server.url(), 100).await; + assert!(matches!(result, Err(ApiError::BadStatus(503)))); + } + + #[tokio::test] + async fn streaming_malformed_json_line_is_skipped() { + let mut server = mockito::Server::new_async().await; + // A bad 
line in the middle should not abort processing + let body = format!( + "{}\nnot valid json\n{}", + r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hello"},"done":false}"#, + r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":" world"},"done":true}"#, + ); + let _mock = server + .mock("POST", "/api/chat") + .with_status(200) + .with_body(body) + .create_async() + .await; + + let (result, _) = run_streaming(&server.url(), 100).await; + // Should still accumulate valid tokens despite the bad line + assert_eq!(result.unwrap(), "Hello world"); + } } \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 1a8656d..f58a1cf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -155,4 +155,62 @@ impl Config { Ok(config_dir.join("config.toml")) } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_config_has_sensible_values() { + let cfg = Config::default(); + + // Ollama endpoint + assert_eq!(cfg.ollama.url, "http://localhost:11434"); + assert!(cfg.ollama.timeout_seconds > 0); + + // Context window: 0 would mean every request sends zero messages — bad default. + assert!(cfg.ollama.max_context_messages > 0, + "max_context_messages must be > 0 or no history is ever sent"); + + // Streaming + assert!(cfg.streaming.batch_size > 0); + assert!(cfg.streaming.batch_timeout_ms > 0); + + // System prompt empty by default (RAG / user sets it when needed) + assert!(cfg.ollama.system_prompt.is_empty()); + } + + #[test] + fn config_roundtrips_through_toml() { + // Serialize to TOML and back — verifies the serde field names are stable. + // If you rename a field in the struct but forget to update the serde attribute, + // existing config files on disk will silently lose that value. This catches it. 
+ let original = Config::default(); + let toml_str = toml::to_string_pretty(&original).unwrap(); + let parsed: Config = toml::from_str(&toml_str).unwrap(); + + assert_eq!(parsed.ollama.url, original.ollama.url); + assert_eq!(parsed.ollama.max_context_messages, original.ollama.max_context_messages); + assert_eq!(parsed.streaming.batch_size, original.streaming.batch_size); + assert_eq!(parsed.streaming.batch_timeout_ms, original.streaming.batch_timeout_ms); + assert_eq!(parsed.ui.window_font_size, original.ui.window_font_size); + } + + #[test] + fn config_roundtrip_preserves_custom_values() { + let mut cfg = Config::default(); + cfg.ollama.url = "http://my-server:11434".to_string(); + cfg.ollama.max_context_messages = 50; + cfg.streaming.batch_size = 5; + cfg.ollama.system_prompt = "You are a helpful assistant.".to_string(); + + let toml_str = toml::to_string_pretty(&cfg).unwrap(); + let parsed: Config = toml::from_str(&toml_str).unwrap(); + + assert_eq!(parsed.ollama.url, "http://my-server:11434"); + assert_eq!(parsed.ollama.max_context_messages, 50); + assert_eq!(parsed.streaming.batch_size, 5); + assert_eq!(parsed.ollama.system_prompt, "You are a helpful assistant."); + } } \ No newline at end of file diff --git a/src/markdown_renderer.rs b/src/markdown_renderer.rs index e252cd3..fd504ee 100644 --- a/src/markdown_renderer.rs +++ b/src/markdown_renderer.rs @@ -181,56 +181,19 @@ impl MarkdownRenderer { } } - /// Process text for streaming, handling think tags in real-time + /// Process text for streaming, handling think tags in real-time. + /// + /// Delegates detection to [`parse_think_segments`] and handles GTK insertions per segment. 
fn process_streaming_text(&mut self, buffer: &TextBuffer, text: &str, iter: &mut TextIter) -> String { let mut result = String::new(); - let mut remaining = text; - - while !remaining.is_empty() { - if self.in_think_tag { - // We're currently inside a think tag, look for closing tag - if let Some(end_pos) = remaining.find("</think>") { - // Found closing tag - stream the remaining think content - let final_think_content = &remaining[..end_pos]; - if !final_think_content.is_empty() { - buffer.insert_with_tags(iter, final_think_content, &[&self.think_tag]); - } - - // Close the think section - buffer.insert(iter, "\n\n"); - - // Reset think state - self.in_think_tag = false; - - // Continue with text after closing tag - remaining = &remaining[end_pos + 8..]; // 8 = "</think>".len() - } else { - // No closing tag yet, stream the think content as it arrives - if !remaining.is_empty() { - buffer.insert_with_tags(iter, remaining, &[&self.think_tag]); - } - break; // Wait for more streaming content - } - } else { - // Not in think tag, look for opening tag - if let Some(start_pos) = remaining.find("<think>") { - // Add content before think tag to result for normal processing - result.push_str(&remaining[..start_pos]); - - // Start think mode and show the think indicator - self.in_think_tag = true; - buffer.insert(iter, "\n💭 "); - - // Continue with content after opening tag - remaining = &remaining[start_pos + 7..]; // 7 = "<think>".len() - } else { - // No think tag found, add all remaining text to result - result.push_str(remaining); - break; - } + for segment in parse_think_segments(text, &mut self.in_think_tag) { + match segment { + StreamSegment::Normal(s) => result.push_str(&s), + StreamSegment::ThinkStart => buffer.insert(iter, "\n💭 "), + StreamSegment::Think(s) => buffer.insert_with_tags(iter, &s, &[&self.think_tag]), + StreamSegment::ThinkEnd => buffer.insert(iter, "\n\n"), + } } - result } @@ -382,6 +345,68 @@ impl MarkdownRenderer { } } +/// A parsed segment from streaming think-tag
processing. +/// +/// Separating the pure detection logic (see [`parse_think_segments`]) from GTK mutations lets us +/// unit-test the state machine without a live display session. +#[derive(Debug, PartialEq)] +enum StreamSegment { + /// Plain text that passes through the markdown renderer. + Normal(String), + /// The `<think>` opening boundary was seen; the GTK layer emits an indicator. + ThinkStart, + /// Content inside a think block, rendered with the think-tag style. + Think(String), + /// The `</think>` closing boundary was seen; the GTK layer emits a separator. + ThinkEnd, +} + +/// Parse `text` into typed [`StreamSegment`]s, updating the in-flight `in_think` cursor. +/// +/// Streaming-safe: a think block may open in one call and close in a later call. +fn parse_think_segments(text: &str, in_think: &mut bool) -> Vec<StreamSegment> { + let mut segments = Vec::new(); + let mut remaining = text; + + while !remaining.is_empty() { + if *in_think { + match remaining.find("</think>") { + Some(end_pos) => { + let content = &remaining[..end_pos]; + if !content.is_empty() { + segments.push(StreamSegment::Think(content.to_string())); + } + segments.push(StreamSegment::ThinkEnd); + *in_think = false; + remaining = &remaining[end_pos + 8..]; // skip "</think>" + } + None => { + segments.push(StreamSegment::Think(remaining.to_string())); + break; + } + } + } else { + match remaining.find("<think>") { + Some(start_pos) => { + let before = &remaining[..start_pos]; + if !before.is_empty() { + segments.push(StreamSegment::Normal(before.to_string())); + } + segments.push(StreamSegment::ThinkStart); + *in_think = true; + remaining = &remaining[start_pos + 7..]; // skip "<think>" + } + None => { + segments.push(StreamSegment::Normal(remaining.to_string())); + break; + } + } + } + } + + segments +} + /// Helper function to parse color strings (hex format) into RGBA fn parse_color(color_str: &str) -> Result<gdk::RGBA, Box<dyn std::error::Error>> { let color_str = color_str.trim_start_matches('#'); @@ -395,4 +420,131 @@ fn parse_color(color_str: &str) -> Result<gdk::RGBA, Box<dyn std::error::Error>> +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn full_think_cycle_in_single_call() { + let mut in_think = false; + let segs = parse_think_segments("before <think>thinking</think> after", 
&mut in_think); + assert_eq!(segs, vec![ + StreamSegment::Normal("before ".into()), + StreamSegment::ThinkStart, + StreamSegment::Think("thinking".into()), + StreamSegment::ThinkEnd, + StreamSegment::Normal(" after".into()), + ]); + assert!(!in_think); + } + + #[test] + fn unclosed_think_tag_leaves_in_think_true() { + let mut in_think = false; + let segs = parse_think_segments("start <think>partial", &mut in_think); + assert_eq!(segs, vec![ + StreamSegment::Normal("start ".into()), + StreamSegment::ThinkStart, + StreamSegment::Think("partial".into()), + ]); + assert!(in_think); + } + + #[test] + fn closing_tag_in_second_call_closes_correctly() { + let mut in_think = true; // simulate carrying over from previous call + let segs = parse_think_segments("rest</think> normal", &mut in_think); + assert_eq!(segs, vec![ + StreamSegment::Think("rest".into()), + StreamSegment::ThinkEnd, + StreamSegment::Normal(" normal".into()), + ]); + assert!(!in_think); + } + + #[test] + fn continuation_while_in_think_produces_think_segment() { + let mut in_think = true; + let segs = parse_think_segments("more thinking...", &mut in_think); + assert_eq!(segs, vec![StreamSegment::Think("more thinking...".into())]); + assert!(in_think); + } + + #[test] + fn empty_think_block_produces_start_and_end_only() { + let mut in_think = false; + let segs = parse_think_segments("<think></think>", &mut in_think); + assert_eq!(segs, vec![ + StreamSegment::ThinkStart, + StreamSegment::ThinkEnd, + ]); + assert!(!in_think); + } + + #[test] + fn think_block_at_very_start() { + let mut in_think = false; + let segs = parse_think_segments("<think>reasoning</think>answer", &mut in_think); + assert_eq!(segs, vec![ + StreamSegment::ThinkStart, + StreamSegment::Think("reasoning".into()), + StreamSegment::ThinkEnd, + StreamSegment::Normal("answer".into()), + ]); + assert!(!in_think); + } } \ No newline at end of file diff --git a/src/state.rs b/src/state.rs index 0a470e4..c41e113 100644 --- a/src/state.rs +++ b/src/state.rs @@ -107,5 +107,109 @@ impl AppState
{ self.set_generating(false); self.set_status("Generation stopped".to_string()); } - + +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_state() -> AppState { + AppState { + conversation: Vec::new(), + ollama_url: "http://localhost:11434".into(), + is_generating: false, + button_state: ButtonState::Send, + current_task: None, + selected_model: None, + status_message: "Ready".into(), + system_prompt: None, + config: Config::default(), + } + } + + #[test] + fn set_generating_true_sets_stop_state() { + let mut state = make_state(); + state.set_generating(true); + assert!(state.is_generating); + assert_eq!(state.button_state, ButtonState::Stop); + } + + #[test] + fn set_generating_false_sets_send_state() { + let mut state = make_state(); + state.is_generating = true; + state.button_state = ButtonState::Stop; + state.set_generating(false); + assert!(!state.is_generating); + assert_eq!(state.button_state, ButtonState::Send); + } + + #[test] + fn add_user_message_appends_with_correct_role() { + let mut state = make_state(); + state.add_user_message("hello".into()); + assert_eq!(state.conversation.len(), 1); + assert_eq!(state.conversation[0].role, "user"); + assert_eq!(state.conversation[0].content, "hello"); + } + + #[test] + fn add_assistant_message_appends_with_correct_role() { + let mut state = make_state(); + state.add_assistant_message("hi there".into()); + assert_eq!(state.conversation.len(), 1); + assert_eq!(state.conversation[0].role, "assistant"); + assert_eq!(state.conversation[0].content, "hi there"); + } + + #[test] + fn conversation_preserves_insertion_order() { + let mut state = make_state(); + state.add_user_message("first".into()); + state.add_assistant_message("second".into()); + state.add_user_message("third".into()); + assert_eq!(state.conversation.len(), 3); + assert_eq!(state.conversation[0].role, "user"); + assert_eq!(state.conversation[1].role, "assistant"); + assert_eq!(state.conversation[2].role, "user"); + } + + #[test] + fn 
set_status_updates_message() { + let mut state = make_state(); + state.set_status("Loading models...".into()); + assert_eq!(state.status_message, "Loading models..."); + } + + #[test] + fn abort_current_task_without_task_resets_state() { + let mut state = make_state(); + state.is_generating = true; + state.button_state = ButtonState::Stop; + state.abort_current_task(); + assert!(!state.is_generating); + assert_eq!(state.button_state, ButtonState::Send); + assert_eq!(state.status_message, "Generation stopped"); + assert!(state.current_task.is_none()); + } + + #[tokio::test] + async fn abort_current_task_aborts_running_task() { + let mut state = make_state(); + // Spawn a task that sleeps forever so we can verify it gets aborted + let handle = tokio::spawn(async { + tokio::time::sleep(tokio::time::Duration::from_secs(3600)).await; + }); + state.current_task = Some(handle); + state.is_generating = true; + state.button_state = ButtonState::Stop; + + state.abort_current_task(); + + assert!(state.current_task.is_none()); + assert!(!state.is_generating); + assert_eq!(state.status_message, "Generation stopped"); + } } \ No newline at end of file diff --git a/src/types.rs b/src/types.rs index 1a16ef1..0da4dbe 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +// --- Types --- + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { pub role: String, @@ -36,4 +38,47 @@ pub struct ModelInfo { #[derive(Debug, Serialize, Deserialize)] pub struct ModelsResponse { pub models: Vec<ModelInfo>, +} + +#[cfg(test)] +mod tests { + use super::*; + + // `cargo test` will run these. This is the standard Rust pattern: + // put tests in a `#[cfg(test)]` module so they're compiled only when testing. + + #[test] + fn chat_message_serializes_and_deserializes() { + // Verify that serde roundtrips correctly — if you ever rename a field + // or change its type, this test catches it before it reaches the API.
+ let msg = ChatMessage { + role: "user".to_string(), + content: "Hello, world!".to_string(), + }; + let json = serde_json::to_string(&msg).unwrap(); + let decoded: ChatMessage = serde_json::from_str(&json).unwrap(); + assert_eq!(decoded.role, msg.role); + assert_eq!(decoded.content, msg.content); + } + + #[test] + fn chat_request_includes_stream_flag() { + let req = ChatRequest { + model: "llama3".to_string(), + messages: vec![], + stream: true, + }; + let json = serde_json::to_string(&req).unwrap(); + // The Ollama API requires `"stream": true` in the body. + assert!(json.contains("\"stream\":true")); + } + + #[test] + fn stream_response_parses_done_flag() { + // Real payload shape from the Ollama streaming API. + let raw = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hi"},"done":true}"#; + let resp: StreamResponse = serde_json::from_str(raw).unwrap(); + assert!(resp.done); + assert_eq!(resp.message.content, "Hi"); + } } \ No newline at end of file diff --git a/src/ui/handlers.rs b/src/ui/handlers.rs index 0dc68c9..8603f64 100644 --- a/src/ui/handlers.rs +++ b/src/ui/handlers.rs @@ -243,7 +243,6 @@ fn setup_streaming_handlers( accumulated_content.push_str(&content_batch); let config = shared_state_streaming.borrow().config.clone(); chat_view_content.update_streaming_markdown(&response_mark_clone, &accumulated_content, &config); - chat_view_content.scroll_to_bottom(); } });