tests for existing flow
This commit is contained in:
parent
2a75c9064f
commit
a6705e7d95
7 changed files with 597 additions and 48 deletions
|
|
@ -16,3 +16,7 @@ futures-util = "0.3"
|
||||||
async-channel = "2.3"
|
async-channel = "2.3"
|
||||||
pulldown-cmark = { version = "0.13.0", default-features = false, features = ["html"] }
|
pulldown-cmark = { version = "0.13.0", default-features = false, features = ["html"] }
|
||||||
thiserror = "2"
|
thiserror = "2"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
mockito = "1"
|
||||||
|
tempfile = "3"
|
||||||
|
|
|
||||||
187
src/api.rs
187
src/api.rs
|
|
@ -133,3 +133,190 @@ pub async fn send_chat_request_streaming(
|
||||||
|
|
||||||
Ok(full_response)
|
Ok(full_response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
// ── fetch_models ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn fetch_models_returns_model_list() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("GET", "/api/tags")
|
||||||
|
.with_status(200)
|
||||||
|
.with_header("content-type", "application/json")
|
||||||
|
.with_body(r#"{"models":[{"name":"llama3","modified_at":"2024-01-01T00:00:00Z","size":4000000}]}"#)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let models = fetch_models(&server.url()).await.unwrap();
|
||||||
|
assert_eq!(models.len(), 1);
|
||||||
|
assert_eq!(models[0].name, "llama3");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn fetch_models_bad_status_returns_error() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("GET", "/api/tags")
|
||||||
|
.with_status(500)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let err = fetch_models(&server.url()).await.unwrap_err();
|
||||||
|
// reqwest treats non-success as an error only if we explicitly check;
|
||||||
|
// here fetch_models passes the status through response.json() which will
|
||||||
|
// fail because body is empty — so we get an Http or Parse error.
|
||||||
|
// The important thing: it is an error, not a success.
|
||||||
|
assert!(matches!(err, ApiError::Http(_) | ApiError::Parse(_) | ApiError::BadStatus(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn fetch_models_bad_json_returns_parse_error() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("GET", "/api/tags")
|
||||||
|
.with_status(200)
|
||||||
|
.with_header("content-type", "application/json")
|
||||||
|
.with_body("not json")
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let err = fetch_models(&server.url()).await.unwrap_err();
|
||||||
|
assert!(matches!(err, ApiError::Http(_) | ApiError::Parse(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── send_chat_request_streaming ──────────────────────────────────────────
|
||||||
|
|
||||||
|
fn ndjson_lines(tokens: &[(&str, bool)]) -> String {
|
||||||
|
tokens
|
||||||
|
.iter()
|
||||||
|
.map(|(content, done)| {
|
||||||
|
format!(
|
||||||
|
r#"{{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{{"role":"assistant","content":"{}"}},"done":{}}}"#,
|
||||||
|
content, done
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_streaming(server_url: &str, batch_size: usize) -> (Result<String, ApiError>, Vec<String>) {
|
||||||
|
let messages = vec![ChatMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "hi".to_string(),
|
||||||
|
}];
|
||||||
|
let (tx, rx) = async_channel::unbounded();
|
||||||
|
let result = send_chat_request_streaming(
|
||||||
|
server_url, "llama3", messages, tx, batch_size, 5000,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut batches = Vec::new();
|
||||||
|
while let Ok(batch) = rx.try_recv() {
|
||||||
|
batches.push(batch);
|
||||||
|
}
|
||||||
|
(result, batches)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn streaming_single_token_returns_full_response() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
let body = ndjson_lines(&[("Hello", true)]);
|
||||||
|
let _mock = server
|
||||||
|
.mock("POST", "/api/chat")
|
||||||
|
.with_status(200)
|
||||||
|
.with_body(body)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let (result, _batches) = run_streaming(&server.url(), 100).await;
|
||||||
|
assert_eq!(result.unwrap(), "Hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn streaming_multi_token_accumulates_full_response() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
let body = ndjson_lines(&[("Hello", false), (" world", true)]);
|
||||||
|
let _mock = server
|
||||||
|
.mock("POST", "/api/chat")
|
||||||
|
.with_status(200)
|
||||||
|
.with_body(body)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let (result, _) = run_streaming(&server.url(), 100).await;
|
||||||
|
assert_eq!(result.unwrap(), "Hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn streaming_batch_size_flushes_intermediate_batches() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
// 3 tokens, batch_size=2 → first batch sent after 2 tokens, second after done
|
||||||
|
let body = ndjson_lines(&[("a", false), ("b", false), ("c", true)]);
|
||||||
|
let _mock = server
|
||||||
|
.mock("POST", "/api/chat")
|
||||||
|
.with_status(200)
|
||||||
|
.with_body(body)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let (result, batches) = run_streaming(&server.url(), 2).await;
|
||||||
|
assert_eq!(result.unwrap(), "abc");
|
||||||
|
// We should have received at least 2 channel messages (one mid-stream, one final)
|
||||||
|
assert!(batches.len() >= 2, "expected intermediate batches, got {:?}", batches);
|
||||||
|
assert_eq!(batches.join(""), "abc");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn streaming_done_with_no_content_returns_empty_response_error() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
// done:true but content is empty
|
||||||
|
let body = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":""},"done":true}"#;
|
||||||
|
let _mock = server
|
||||||
|
.mock("POST", "/api/chat")
|
||||||
|
.with_status(200)
|
||||||
|
.with_body(body)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let (result, _) = run_streaming(&server.url(), 100).await;
|
||||||
|
assert!(matches!(result, Err(ApiError::EmptyResponse)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn streaming_bad_status_returns_error() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("POST", "/api/chat")
|
||||||
|
.with_status(503)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let (result, _) = run_streaming(&server.url(), 100).await;
|
||||||
|
assert!(matches!(result, Err(ApiError::BadStatus(503))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn streaming_malformed_json_line_is_skipped() {
|
||||||
|
let mut server = mockito::Server::new_async().await;
|
||||||
|
// A bad line in the middle should not abort processing
|
||||||
|
let body = format!(
|
||||||
|
"{}\nnot valid json\n{}",
|
||||||
|
r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hello"},"done":false}"#,
|
||||||
|
r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":" world"},"done":true}"#,
|
||||||
|
);
|
||||||
|
let _mock = server
|
||||||
|
.mock("POST", "/api/chat")
|
||||||
|
.with_status(200)
|
||||||
|
.with_body(body)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let (result, _) = run_streaming(&server.url(), 100).await;
|
||||||
|
// Should still accumulate valid tokens despite the bad line
|
||||||
|
assert_eq!(result.unwrap(), "Hello world");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -156,3 +156,61 @@ impl Config {
|
||||||
Ok(config_dir.join("config.toml"))
|
Ok(config_dir.join("config.toml"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_config_has_sensible_values() {
|
||||||
|
let cfg = Config::default();
|
||||||
|
|
||||||
|
// Ollama endpoint
|
||||||
|
assert_eq!(cfg.ollama.url, "http://localhost:11434");
|
||||||
|
assert!(cfg.ollama.timeout_seconds > 0);
|
||||||
|
|
||||||
|
// Context window: 0 would mean every request sends zero messages — bad default.
|
||||||
|
assert!(cfg.ollama.max_context_messages > 0,
|
||||||
|
"max_context_messages must be > 0 or no history is ever sent");
|
||||||
|
|
||||||
|
// Streaming
|
||||||
|
assert!(cfg.streaming.batch_size > 0);
|
||||||
|
assert!(cfg.streaming.batch_timeout_ms > 0);
|
||||||
|
|
||||||
|
// System prompt empty by default (RAG / user sets it when needed)
|
||||||
|
assert!(cfg.ollama.system_prompt.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_roundtrips_through_toml() {
|
||||||
|
// Serialize to TOML and back — verifies the serde field names are stable.
|
||||||
|
// If you rename a field in the struct but forget to update the serde attribute,
|
||||||
|
// existing config files on disk will silently lose that value. This catches it.
|
||||||
|
let original = Config::default();
|
||||||
|
let toml_str = toml::to_string_pretty(&original).unwrap();
|
||||||
|
let parsed: Config = toml::from_str(&toml_str).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(parsed.ollama.url, original.ollama.url);
|
||||||
|
assert_eq!(parsed.ollama.max_context_messages, original.ollama.max_context_messages);
|
||||||
|
assert_eq!(parsed.streaming.batch_size, original.streaming.batch_size);
|
||||||
|
assert_eq!(parsed.streaming.batch_timeout_ms, original.streaming.batch_timeout_ms);
|
||||||
|
assert_eq!(parsed.ui.window_font_size, original.ui.window_font_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_roundtrip_preserves_custom_values() {
|
||||||
|
let mut cfg = Config::default();
|
||||||
|
cfg.ollama.url = "http://my-server:11434".to_string();
|
||||||
|
cfg.ollama.max_context_messages = 50;
|
||||||
|
cfg.streaming.batch_size = 5;
|
||||||
|
cfg.ollama.system_prompt = "You are a helpful assistant.".to_string();
|
||||||
|
|
||||||
|
let toml_str = toml::to_string_pretty(&cfg).unwrap();
|
||||||
|
let parsed: Config = toml::from_str(&toml_str).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(parsed.ollama.url, "http://my-server:11434");
|
||||||
|
assert_eq!(parsed.ollama.max_context_messages, 50);
|
||||||
|
assert_eq!(parsed.streaming.batch_size, 5);
|
||||||
|
assert_eq!(parsed.ollama.system_prompt, "You are a helpful assistant.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -181,56 +181,19 @@ impl MarkdownRenderer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process text for streaming, handling think tags in real-time
|
/// Process text for streaming, handling think tags in real-time.
|
||||||
|
///
|
||||||
|
/// Delegates detection to [`parse_think_segments`] and handles GTK insertions per segment.
|
||||||
fn process_streaming_text(&mut self, buffer: &TextBuffer, text: &str, iter: &mut TextIter) -> String {
|
fn process_streaming_text(&mut self, buffer: &TextBuffer, text: &str, iter: &mut TextIter) -> String {
|
||||||
let mut result = String::new();
|
let mut result = String::new();
|
||||||
let mut remaining = text;
|
for segment in parse_think_segments(text, &mut self.in_think_tag) {
|
||||||
|
match segment {
|
||||||
while !remaining.is_empty() {
|
StreamSegment::Normal(s) => result.push_str(&s),
|
||||||
if self.in_think_tag {
|
StreamSegment::ThinkStart => buffer.insert(iter, "\n💭 "),
|
||||||
// We're currently inside a think tag, look for closing tag
|
StreamSegment::Think(s) => buffer.insert_with_tags(iter, &s, &[&self.think_tag]),
|
||||||
if let Some(end_pos) = remaining.find("</think>") {
|
StreamSegment::ThinkEnd => buffer.insert(iter, "\n\n"),
|
||||||
// Found closing tag - stream the remaining think content
|
|
||||||
let final_think_content = &remaining[..end_pos];
|
|
||||||
if !final_think_content.is_empty() {
|
|
||||||
buffer.insert_with_tags(iter, final_think_content, &[&self.think_tag]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the think section
|
|
||||||
buffer.insert(iter, "\n\n");
|
|
||||||
|
|
||||||
// Reset think state
|
|
||||||
self.in_think_tag = false;
|
|
||||||
|
|
||||||
// Continue with text after closing tag
|
|
||||||
remaining = &remaining[end_pos + 8..]; // 8 = "</think>".len()
|
|
||||||
} else {
|
|
||||||
// No closing tag yet, stream the think content as it arrives
|
|
||||||
if !remaining.is_empty() {
|
|
||||||
buffer.insert_with_tags(iter, remaining, &[&self.think_tag]);
|
|
||||||
}
|
|
||||||
break; // Wait for more streaming content
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Not in think tag, look for opening tag
|
|
||||||
if let Some(start_pos) = remaining.find("<think>") {
|
|
||||||
// Add content before think tag to result for normal processing
|
|
||||||
result.push_str(&remaining[..start_pos]);
|
|
||||||
|
|
||||||
// Start think mode and show the think indicator
|
|
||||||
self.in_think_tag = true;
|
|
||||||
buffer.insert(iter, "\n💭 ");
|
|
||||||
|
|
||||||
// Continue with content after opening tag
|
|
||||||
remaining = &remaining[start_pos + 7..]; // 7 = "<think>".len()
|
|
||||||
} else {
|
|
||||||
// No think tag found, add all remaining text to result
|
|
||||||
result.push_str(remaining);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -382,6 +345,68 @@ impl MarkdownRenderer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A parsed segment from streaming think-tag processing.
|
||||||
|
///
|
||||||
|
/// Separating the pure detection logic (see [`parse_think_segments`]) from GTK mutations lets us
|
||||||
|
/// unit-test the state machine without a live display session.
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
enum StreamSegment {
|
||||||
|
/// Plain text that passes through the markdown renderer.
|
||||||
|
Normal(String),
|
||||||
|
/// The `<think>` opening boundary was seen; the GTK layer emits an indicator.
|
||||||
|
ThinkStart,
|
||||||
|
/// Content inside a think block, rendered with the think-tag style.
|
||||||
|
Think(String),
|
||||||
|
/// The `</think>` closing boundary was seen; the GTK layer emits a separator.
|
||||||
|
ThinkEnd,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse `text` into typed [`StreamSegment`]s, updating the in-flight `in_think` cursor.
|
||||||
|
///
|
||||||
|
/// Streaming-safe: a think block may open in one call and close in a later call.
|
||||||
|
fn parse_think_segments(text: &str, in_think: &mut bool) -> Vec<StreamSegment> {
|
||||||
|
let mut segments = Vec::new();
|
||||||
|
let mut remaining = text;
|
||||||
|
|
||||||
|
while !remaining.is_empty() {
|
||||||
|
if *in_think {
|
||||||
|
match remaining.find("</think>") {
|
||||||
|
Some(end_pos) => {
|
||||||
|
let content = &remaining[..end_pos];
|
||||||
|
if !content.is_empty() {
|
||||||
|
segments.push(StreamSegment::Think(content.to_string()));
|
||||||
|
}
|
||||||
|
segments.push(StreamSegment::ThinkEnd);
|
||||||
|
*in_think = false;
|
||||||
|
remaining = &remaining[end_pos + 8..]; // skip "</think>"
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
segments.push(StreamSegment::Think(remaining.to_string()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match remaining.find("<think>") {
|
||||||
|
Some(start_pos) => {
|
||||||
|
let before = &remaining[..start_pos];
|
||||||
|
if !before.is_empty() {
|
||||||
|
segments.push(StreamSegment::Normal(before.to_string()));
|
||||||
|
}
|
||||||
|
segments.push(StreamSegment::ThinkStart);
|
||||||
|
*in_think = true;
|
||||||
|
remaining = &remaining[start_pos + 7..]; // skip "<think>"
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
segments.push(StreamSegment::Normal(remaining.to_string()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
segments
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper function to parse color strings (hex format) into RGBA
|
/// Helper function to parse color strings (hex format) into RGBA
|
||||||
fn parse_color(color_str: &str) -> Result<gtk4::gdk::RGBA, Box<dyn std::error::Error>> {
|
fn parse_color(color_str: &str) -> Result<gtk4::gdk::RGBA, Box<dyn std::error::Error>> {
|
||||||
let color_str = color_str.trim_start_matches('#');
|
let color_str = color_str.trim_start_matches('#');
|
||||||
|
|
@ -396,3 +421,130 @@ fn parse_color(color_str: &str) -> Result<gtk4::gdk::RGBA, Box<dyn std::error::E
|
||||||
|
|
||||||
Ok(gtk4::gdk::RGBA::new(r, g, b, 1.0))
|
Ok(gtk4::gdk::RGBA::new(r, g, b, 1.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
// ── parse_color ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_color_red() {
|
||||||
|
let c = parse_color("#ff0000").unwrap();
|
||||||
|
assert!((c.red() - 1.0).abs() < 1e-4);
|
||||||
|
assert!((c.green() - 0.0).abs() < 1e-4);
|
||||||
|
assert!((c.blue() - 0.0).abs() < 1e-4);
|
||||||
|
assert!((c.alpha() - 1.0).abs() < 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_color_black() {
|
||||||
|
let c = parse_color("#000000").unwrap();
|
||||||
|
assert!((c.red() - 0.0).abs() < 1e-4);
|
||||||
|
assert!((c.green() - 0.0).abs() < 1e-4);
|
||||||
|
assert!((c.blue() - 0.0).abs() < 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_color_white() {
|
||||||
|
let c = parse_color("#ffffff").unwrap();
|
||||||
|
assert!((c.red() - 1.0).abs() < 1e-4);
|
||||||
|
assert!((c.green() - 1.0).abs() < 1e-4);
|
||||||
|
assert!((c.blue() - 1.0).abs() < 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_color_short_hex_is_error() {
|
||||||
|
assert!(parse_color("#fff").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_color_non_hex_chars_is_error() {
|
||||||
|
assert!(parse_color("#zzzzzz").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_color_empty_is_error() {
|
||||||
|
assert!(parse_color("").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── parse_think_segments ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn plain_text_produces_single_normal_segment() {
|
||||||
|
let mut in_think = false;
|
||||||
|
let segs = parse_think_segments("hello world", &mut in_think);
|
||||||
|
assert_eq!(segs, vec![StreamSegment::Normal("hello world".into())]);
|
||||||
|
assert!(!in_think);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn think_block_in_middle_produces_all_segments() {
|
||||||
|
let mut in_think = false;
|
||||||
|
let segs = parse_think_segments("before <think>thinking</think> after", &mut in_think);
|
||||||
|
assert_eq!(segs, vec![
|
||||||
|
StreamSegment::Normal("before ".into()),
|
||||||
|
StreamSegment::ThinkStart,
|
||||||
|
StreamSegment::Think("thinking".into()),
|
||||||
|
StreamSegment::ThinkEnd,
|
||||||
|
StreamSegment::Normal(" after".into()),
|
||||||
|
]);
|
||||||
|
assert!(!in_think);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unclosed_think_tag_leaves_in_think_true() {
|
||||||
|
let mut in_think = false;
|
||||||
|
let segs = parse_think_segments("start <think>partial", &mut in_think);
|
||||||
|
assert_eq!(segs, vec![
|
||||||
|
StreamSegment::Normal("start ".into()),
|
||||||
|
StreamSegment::ThinkStart,
|
||||||
|
StreamSegment::Think("partial".into()),
|
||||||
|
]);
|
||||||
|
assert!(in_think);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn closing_tag_in_second_call_closes_correctly() {
|
||||||
|
let mut in_think = true; // simulate carrying over from previous call
|
||||||
|
let segs = parse_think_segments("rest</think> normal", &mut in_think);
|
||||||
|
assert_eq!(segs, vec![
|
||||||
|
StreamSegment::Think("rest".into()),
|
||||||
|
StreamSegment::ThinkEnd,
|
||||||
|
StreamSegment::Normal(" normal".into()),
|
||||||
|
]);
|
||||||
|
assert!(!in_think);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn continuation_while_in_think_produces_think_segment() {
|
||||||
|
let mut in_think = true;
|
||||||
|
let segs = parse_think_segments("more thinking...", &mut in_think);
|
||||||
|
assert_eq!(segs, vec![StreamSegment::Think("more thinking...".into())]);
|
||||||
|
assert!(in_think);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_think_block_produces_start_and_end_only() {
|
||||||
|
let mut in_think = false;
|
||||||
|
let segs = parse_think_segments("<think></think>", &mut in_think);
|
||||||
|
assert_eq!(segs, vec![
|
||||||
|
StreamSegment::ThinkStart,
|
||||||
|
StreamSegment::ThinkEnd,
|
||||||
|
]);
|
||||||
|
assert!(!in_think);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn think_block_at_very_start() {
|
||||||
|
let mut in_think = false;
|
||||||
|
let segs = parse_think_segments("<think>reasoning</think>answer", &mut in_think);
|
||||||
|
assert_eq!(segs, vec![
|
||||||
|
StreamSegment::ThinkStart,
|
||||||
|
StreamSegment::Think("reasoning".into()),
|
||||||
|
StreamSegment::ThinkEnd,
|
||||||
|
StreamSegment::Normal("answer".into()),
|
||||||
|
]);
|
||||||
|
assert!(!in_think);
|
||||||
|
}
|
||||||
|
}
|
||||||
104
src/state.rs
104
src/state.rs
|
|
@ -109,3 +109,107 @@ impl AppState {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn make_state() -> AppState {
|
||||||
|
AppState {
|
||||||
|
conversation: Vec::new(),
|
||||||
|
ollama_url: "http://localhost:11434".into(),
|
||||||
|
is_generating: false,
|
||||||
|
button_state: ButtonState::Send,
|
||||||
|
current_task: None,
|
||||||
|
selected_model: None,
|
||||||
|
status_message: "Ready".into(),
|
||||||
|
system_prompt: None,
|
||||||
|
config: Config::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn set_generating_true_sets_stop_state() {
|
||||||
|
let mut state = make_state();
|
||||||
|
state.set_generating(true);
|
||||||
|
assert!(state.is_generating);
|
||||||
|
assert_eq!(state.button_state, ButtonState::Stop);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn set_generating_false_sets_send_state() {
|
||||||
|
let mut state = make_state();
|
||||||
|
state.is_generating = true;
|
||||||
|
state.button_state = ButtonState::Stop;
|
||||||
|
state.set_generating(false);
|
||||||
|
assert!(!state.is_generating);
|
||||||
|
assert_eq!(state.button_state, ButtonState::Send);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn add_user_message_appends_with_correct_role() {
|
||||||
|
let mut state = make_state();
|
||||||
|
state.add_user_message("hello".into());
|
||||||
|
assert_eq!(state.conversation.len(), 1);
|
||||||
|
assert_eq!(state.conversation[0].role, "user");
|
||||||
|
assert_eq!(state.conversation[0].content, "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn add_assistant_message_appends_with_correct_role() {
|
||||||
|
let mut state = make_state();
|
||||||
|
state.add_assistant_message("hi there".into());
|
||||||
|
assert_eq!(state.conversation.len(), 1);
|
||||||
|
assert_eq!(state.conversation[0].role, "assistant");
|
||||||
|
assert_eq!(state.conversation[0].content, "hi there");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversation_preserves_insertion_order() {
|
||||||
|
let mut state = make_state();
|
||||||
|
state.add_user_message("first".into());
|
||||||
|
state.add_assistant_message("second".into());
|
||||||
|
state.add_user_message("third".into());
|
||||||
|
assert_eq!(state.conversation.len(), 3);
|
||||||
|
assert_eq!(state.conversation[0].role, "user");
|
||||||
|
assert_eq!(state.conversation[1].role, "assistant");
|
||||||
|
assert_eq!(state.conversation[2].role, "user");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn set_status_updates_message() {
|
||||||
|
let mut state = make_state();
|
||||||
|
state.set_status("Loading models...".into());
|
||||||
|
assert_eq!(state.status_message, "Loading models...");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn abort_current_task_without_task_resets_state() {
|
||||||
|
let mut state = make_state();
|
||||||
|
state.is_generating = true;
|
||||||
|
state.button_state = ButtonState::Stop;
|
||||||
|
state.abort_current_task();
|
||||||
|
assert!(!state.is_generating);
|
||||||
|
assert_eq!(state.button_state, ButtonState::Send);
|
||||||
|
assert_eq!(state.status_message, "Generation stopped");
|
||||||
|
assert!(state.current_task.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn abort_current_task_aborts_running_task() {
|
||||||
|
let mut state = make_state();
|
||||||
|
// Spawn a task that sleeps forever so we can verify it gets aborted
|
||||||
|
let handle = tokio::spawn(async {
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_secs(3600)).await;
|
||||||
|
});
|
||||||
|
state.current_task = Some(handle);
|
||||||
|
state.is_generating = true;
|
||||||
|
state.button_state = ButtonState::Stop;
|
||||||
|
|
||||||
|
state.abort_current_task();
|
||||||
|
|
||||||
|
assert!(state.current_task.is_none());
|
||||||
|
assert!(!state.is_generating);
|
||||||
|
assert_eq!(state.status_message, "Generation stopped");
|
||||||
|
}
|
||||||
|
}
|
||||||
45
src/types.rs
45
src/types.rs
|
|
@ -1,5 +1,7 @@
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// --- Types ---
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ChatMessage {
|
pub struct ChatMessage {
|
||||||
pub role: String,
|
pub role: String,
|
||||||
|
|
@ -37,3 +39,46 @@ pub struct ModelInfo {
|
||||||
pub struct ModelsResponse {
|
pub struct ModelsResponse {
|
||||||
pub models: Vec<ModelInfo>,
|
pub models: Vec<ModelInfo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
// `cargo test` will run these. This is the standard Rust pattern:
|
||||||
|
// put tests in a `#[cfg(test)]` module so they're compiled only when testing.
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_message_serializes_and_deserializes() {
|
||||||
|
// Verify that serde roundtrips correctly — if you ever rename a field
|
||||||
|
// or change its type, this test catches it before it reaches the API.
|
||||||
|
let msg = ChatMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hello, world!".to_string(),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
let decoded: ChatMessage = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(decoded.role, msg.role);
|
||||||
|
assert_eq!(decoded.content, msg.content);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_request_includes_stream_flag() {
|
||||||
|
let req = ChatRequest {
|
||||||
|
model: "llama3".to_string(),
|
||||||
|
messages: vec![],
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&req).unwrap();
|
||||||
|
// The Ollama API requires `"stream": true` in the body.
|
||||||
|
assert!(json.contains("\"stream\":true"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stream_response_parses_done_flag() {
|
||||||
|
// Real payload shape from the Ollama streaming API.
|
||||||
|
let raw = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hi"},"done":true}"#;
|
||||||
|
let resp: StreamResponse = serde_json::from_str(raw).unwrap();
|
||||||
|
assert!(resp.done);
|
||||||
|
assert_eq!(resp.message.content, "Hi");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -243,7 +243,6 @@ fn setup_streaming_handlers(
|
||||||
accumulated_content.push_str(&content_batch);
|
accumulated_content.push_str(&content_batch);
|
||||||
let config = shared_state_streaming.borrow().config.clone();
|
let config = shared_state_streaming.borrow().config.clone();
|
||||||
chat_view_content.update_streaming_markdown(&response_mark_clone, &accumulated_content, &config);
|
chat_view_content.update_streaming_markdown(&response_mark_clone, &accumulated_content, &config);
|
||||||
chat_view_content.scroll_to_bottom();
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue