Compare commits

...
Sign in to create a new pull request.

2 commits

9 changed files with 720 additions and 126 deletions

View file

@ -15,3 +15,8 @@ dirs = "6.0"
futures-util = "0.3"
async-channel = "2.3"
pulldown-cmark = { version = "0.13.0", default-features = false, features = ["html"] }
thiserror = "2"
[dev-dependencies]
mockito = "1"
tempfile = "3"

View file

@ -61,6 +61,10 @@ stop_button = "#dc3545"
[ollama]
url = "http://localhost:11434"
timeout_seconds = 120
[streaming]
batch_size = 20
batch_timeout_ms = 100
```
## Building

View file

@ -1,13 +1,29 @@
use std::sync::{Arc, Mutex};
use futures_util::StreamExt;
use tokio::time::{timeout, Duration};
use crate::types::{ChatMessage, ChatRequest, ModelInfo, ModelsResponse, StreamResponse};
pub async fn fetch_models(base_url: &str) -> Result<Vec<ModelInfo>, Box<dyn std::error::Error + Send + Sync>> {
/// Typed errors for the Ollama API layer. Using `thiserror` means callers can match
/// on exactly what went wrong instead of downcasting a `Box<dyn Error>`.
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
    /// Transport-level failure from `reqwest` (connection refused, DNS, TLS, ...).
    #[error("HTTP request failed: {0}")]
    Http(#[from] reqwest::Error),
    /// The client-side deadline elapsed before the server answered.
    #[error("Request timed out")]
    Timeout,
    /// The server responded, but with a non-success HTTP status code.
    #[error("Server returned error status {0}")]
    BadStatus(u16),
    /// The response body could not be decoded into the expected JSON shape.
    #[error("Failed to parse response: {0}")]
    Parse(#[from] serde_json::Error),
    /// The stream finished without producing any content.
    #[error("Model returned empty response")]
    EmptyResponse,
}
/// Fetch the list of locally available models from the Ollama `/api/tags` endpoint.
///
/// # Errors
/// Returns [`ApiError::Timeout`] if the server does not answer within 10 seconds,
/// [`ApiError::Http`] on transport failure, and [`ApiError::Http`]/[`ApiError::Parse`]
/// if the body cannot be decoded as a `ModelsResponse`.
pub async fn fetch_models(base_url: &str) -> Result<Vec<ModelInfo>, ApiError> {
    let url = format!("{}/api/tags", base_url);
    // 10s timeout prevents hanging indefinitely when the server is unreachable.
    // The outer `?` handles the elapsed timer, the inner `?` the reqwest error.
    let response = timeout(Duration::from_secs(10), reqwest::get(&url))
        .await
        .map_err(|_| ApiError::Timeout)??;
    let models_response: ModelsResponse = response.json().await?;
    Ok(models_response.models)
}
@ -15,13 +31,11 @@ pub async fn fetch_models(base_url: &str) -> Result<Vec<ModelInfo>, Box<dyn std:
pub async fn send_chat_request_streaming(
base_url: &str,
model: &str,
conversation: &Arc<Mutex<Vec<ChatMessage>>>,
messages: Vec<ChatMessage>,
token_sender: async_channel::Sender<String>,
) -> Result<(String, Option<String>), Box<dyn std::error::Error + Send + Sync>> {
let messages = {
let conversation = conversation.lock().unwrap();
conversation.iter().cloned().collect::<Vec<_>>()
};
batch_size: usize,
batch_timeout_ms: u64,
) -> Result<String, ApiError> {
let request = ChatRequest {
model: model.to_string(),
@ -42,15 +56,14 @@ pub async fn send_chat_request_streaming(
.await?;
if !response.status().is_success() {
return Err(format!("API request failed with status: {}", response.status()).into());
return Err(ApiError::BadStatus(response.status().as_u16()));
}
let mut stream = response.bytes_stream();
let mut full_response = String::new();
let mut current_batch = String::new();
let mut tokens_since_last_send = 0;
const BATCH_SIZE: usize = 20;
const BATCH_TIMEOUT: Duration = Duration::from_millis(100);
let batch_timeout = Duration::from_millis(batch_timeout_ms);
let mut last_send = tokio::time::Instant::now();
@ -74,8 +87,8 @@ pub async fn send_chat_request_streaming(
}
// Send batch if conditions are met
let should_send = tokens_since_last_send >= BATCH_SIZE
|| last_send.elapsed() >= BATCH_TIMEOUT
let should_send = tokens_since_last_send >= batch_size
|| last_send.elapsed() >= batch_timeout
|| stream_response.done;
if should_send {
@ -115,8 +128,195 @@ pub async fn send_chat_request_streaming(
drop(token_sender);
if full_response.is_empty() {
return Err("No response received from the model".into());
return Err(ApiError::EmptyResponse);
}
Ok((full_response, None))
Ok(full_response)
}
#[cfg(test)]
mod tests {
    use super::*;

    // ── fetch_models ─────────────────────────────────────────────────────────

    /// Happy path: a well-formed /api/tags payload yields the parsed model list.
    #[tokio::test]
    async fn fetch_models_returns_model_list() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/api/tags")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"models":[{"name":"llama3","modified_at":"2024-01-01T00:00:00Z","size":4000000}]}"#)
            .create_async()
            .await;
        let models = fetch_models(&server.url()).await.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "llama3");
    }

    #[tokio::test]
    async fn fetch_models_bad_status_returns_error() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/api/tags")
            .with_status(500)
            .create_async()
            .await;
        let err = fetch_models(&server.url()).await.unwrap_err();
        // reqwest treats non-success as an error only if we explicitly check;
        // here fetch_models passes the status through response.json() which will
        // fail because body is empty — so we get an Http or Parse error.
        // The important thing: it is an error, not a success.
        assert!(matches!(err, ApiError::Http(_) | ApiError::Parse(_) | ApiError::BadStatus(_)));
    }

    #[tokio::test]
    async fn fetch_models_bad_json_returns_parse_error() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/api/tags")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("not json")
            .create_async()
            .await;
        let err = fetch_models(&server.url()).await.unwrap_err();
        // Depending on reqwest's decode path this may surface as Http or Parse;
        // either way a malformed body must never produce an Ok.
        assert!(matches!(err, ApiError::Http(_) | ApiError::Parse(_)));
    }

    // ── send_chat_request_streaming ──────────────────────────────────────────

    /// Build an NDJSON body (one JSON object per line) in the shape the Ollama
    /// streaming API emits; `done` marks each chunk's completion flag.
    fn ndjson_lines(tokens: &[(&str, bool)]) -> String {
        tokens
            .iter()
            .map(|(content, done)| {
                format!(
                    r#"{{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{{"role":"assistant","content":"{}"}},"done":{}}}"#,
                    content, done
                )
            })
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Drive one streaming request against `server_url` and return the final
    /// result together with every batch pushed down the token channel.
    /// Uses a 5000 ms batch timeout so only `batch_size` controls flushing.
    async fn run_streaming(server_url: &str, batch_size: usize) -> (Result<String, ApiError>, Vec<String>) {
        let messages = vec![ChatMessage {
            role: "user".to_string(),
            content: "hi".to_string(),
        }];
        let (tx, rx) = async_channel::unbounded();
        let result = send_chat_request_streaming(
            server_url, "llama3", messages, tx, batch_size, 5000,
        )
        .await;
        let mut batches = Vec::new();
        while let Ok(batch) = rx.try_recv() {
            batches.push(batch);
        }
        (result, batches)
    }

    #[tokio::test]
    async fn streaming_single_token_returns_full_response() {
        let mut server = mockito::Server::new_async().await;
        let body = ndjson_lines(&[("Hello", true)]);
        let _mock = server
            .mock("POST", "/api/chat")
            .with_status(200)
            .with_body(body)
            .create_async()
            .await;
        let (result, _batches) = run_streaming(&server.url(), 100).await;
        assert_eq!(result.unwrap(), "Hello");
    }

    #[tokio::test]
    async fn streaming_multi_token_accumulates_full_response() {
        let mut server = mockito::Server::new_async().await;
        let body = ndjson_lines(&[("Hello", false), (" world", true)]);
        let _mock = server
            .mock("POST", "/api/chat")
            .with_status(200)
            .with_body(body)
            .create_async()
            .await;
        let (result, _) = run_streaming(&server.url(), 100).await;
        assert_eq!(result.unwrap(), "Hello world");
    }

    #[tokio::test]
    async fn streaming_batch_size_flushes_intermediate_batches() {
        let mut server = mockito::Server::new_async().await;
        // 3 tokens, batch_size=2 → first batch sent after 2 tokens, second after done
        let body = ndjson_lines(&[("a", false), ("b", false), ("c", true)]);
        let _mock = server
            .mock("POST", "/api/chat")
            .with_status(200)
            .with_body(body)
            .create_async()
            .await;
        let (result, batches) = run_streaming(&server.url(), 2).await;
        assert_eq!(result.unwrap(), "abc");
        // We should have received at least 2 channel messages (one mid-stream, one final)
        assert!(batches.len() >= 2, "expected intermediate batches, got {:?}", batches);
        assert_eq!(batches.join(""), "abc");
    }

    #[tokio::test]
    async fn streaming_done_with_no_content_returns_empty_response_error() {
        let mut server = mockito::Server::new_async().await;
        // done:true but content is empty
        let body = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":""},"done":true}"#;
        let _mock = server
            .mock("POST", "/api/chat")
            .with_status(200)
            .with_body(body)
            .create_async()
            .await;
        let (result, _) = run_streaming(&server.url(), 100).await;
        assert!(matches!(result, Err(ApiError::EmptyResponse)));
    }

    #[tokio::test]
    async fn streaming_bad_status_returns_error() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/api/chat")
            .with_status(503)
            .create_async()
            .await;
        let (result, _) = run_streaming(&server.url(), 100).await;
        assert!(matches!(result, Err(ApiError::BadStatus(503))));
    }

    #[tokio::test]
    async fn streaming_malformed_json_line_is_skipped() {
        let mut server = mockito::Server::new_async().await;
        // A bad line in the middle should not abort processing
        let body = format!(
            "{}\nnot valid json\n{}",
            r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hello"},"done":false}"#,
            r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":" world"},"done":true}"#,
        );
        let _mock = server
            .mock("POST", "/api/chat")
            .with_status(200)
            .with_body(body)
            .create_async()
            .await;
        let (result, _) = run_streaming(&server.url(), 100).await;
        // Should still accumulate valid tokens despite the bad line
        assert_eq!(result.unwrap(), "Hello world");
    }
}

View file

@ -7,6 +7,7 @@ pub struct Config {
pub ui: UiConfig,
pub colors: ColorConfig,
pub ollama: OllamaConfig,
pub streaming: StreamingConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -39,6 +40,20 @@ pub struct ColorConfig {
pub struct OllamaConfig {
    /// Base URL of the Ollama server, e.g. "http://localhost:11434".
    pub url: String,
    /// Per-request timeout, in seconds.
    pub timeout_seconds: u64,
    /// Maximum number of conversation turns sent to the model (most recent N messages).
    /// Keeps context within the model's limit. Set higher for longer memory.
    pub max_context_messages: usize,
    /// Optional system prompt prepended to every conversation.
    /// Leave empty ("") to disable. RAG can override this at runtime via AppState.
    pub system_prompt: String,
}
/// Tuning knobs for how streamed tokens are batched before being flushed to the UI.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingConfig {
    /// Number of tokens to accumulate before flushing to the UI.
    pub batch_size: usize,
    /// Maximum milliseconds to wait before flushing a partial batch.
    pub batch_timeout_ms: u64,
}
impl Default for Config {
@ -47,6 +62,7 @@ impl Default for Config {
ui: UiConfig::default(),
colors: ColorConfig::default(),
ollama: OllamaConfig::default(),
streaming: StreamingConfig::default(),
}
}
}
@ -88,6 +104,17 @@ impl Default for OllamaConfig {
Self {
url: "http://localhost:11434".to_string(),
timeout_seconds: 120,
max_context_messages: 20,
system_prompt: String::new(),
}
}
}
impl Default for StreamingConfig {
fn default() -> Self {
Self {
batch_size: 20,
batch_timeout_ms: 100,
}
}
}
@ -129,3 +156,61 @@ impl Config {
Ok(config_dir.join("config.toml"))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guard against someone shipping a Default that makes the app unusable.
    #[test]
    fn default_config_has_sensible_values() {
        let cfg = Config::default();
        // Ollama endpoint
        assert_eq!(cfg.ollama.url, "http://localhost:11434");
        assert!(cfg.ollama.timeout_seconds > 0);
        // Context window: 0 would mean every request sends zero messages — bad default.
        assert!(cfg.ollama.max_context_messages > 0,
            "max_context_messages must be > 0 or no history is ever sent");
        // Streaming
        assert!(cfg.streaming.batch_size > 0);
        assert!(cfg.streaming.batch_timeout_ms > 0);
        // System prompt empty by default (RAG / user sets it when needed)
        assert!(cfg.ollama.system_prompt.is_empty());
    }

    #[test]
    fn config_roundtrips_through_toml() {
        // Serialize to TOML and back — verifies the serde field names are stable.
        // If you rename a field in the struct but forget to update the serde attribute,
        // existing config files on disk will silently lose that value. This catches it.
        let original = Config::default();
        let toml_str = toml::to_string_pretty(&original).unwrap();
        let parsed: Config = toml::from_str(&toml_str).unwrap();
        assert_eq!(parsed.ollama.url, original.ollama.url);
        assert_eq!(parsed.ollama.max_context_messages, original.ollama.max_context_messages);
        assert_eq!(parsed.streaming.batch_size, original.streaming.batch_size);
        assert_eq!(parsed.streaming.batch_timeout_ms, original.streaming.batch_timeout_ms);
        assert_eq!(parsed.ui.window_font_size, original.ui.window_font_size);
    }

    /// Same roundtrip as above, but with non-default values, so a field whose
    /// serialization silently falls back to Default would still be caught.
    #[test]
    fn config_roundtrip_preserves_custom_values() {
        let mut cfg = Config::default();
        cfg.ollama.url = "http://my-server:11434".to_string();
        cfg.ollama.max_context_messages = 50;
        cfg.streaming.batch_size = 5;
        cfg.ollama.system_prompt = "You are a helpful assistant.".to_string();
        let toml_str = toml::to_string_pretty(&cfg).unwrap();
        let parsed: Config = toml::from_str(&toml_str).unwrap();
        assert_eq!(parsed.ollama.url, "http://my-server:11434");
        assert_eq!(parsed.ollama.max_context_messages, 50);
        assert_eq!(parsed.streaming.batch_size, 5);
        assert_eq!(parsed.ollama.system_prompt, "You are a helpful assistant.");
    }
}

View file

@ -25,7 +25,6 @@ pub struct MarkdownRenderer {
tags_setup: bool,
// State for streaming think tag processing
in_think_tag: bool,
think_buffer: String,
}
impl MarkdownRenderer {
@ -47,7 +46,6 @@ impl MarkdownRenderer {
format_stack: Vec::new(),
tags_setup: false,
in_think_tag: false,
think_buffer: String::new(),
}
}
@ -183,58 +181,19 @@ impl MarkdownRenderer {
}
}
/// Process text for streaming, handling think tags in real-time
/// Process text for streaming, handling think tags in real-time.
///
/// Delegates detection to [`parse_think_segments`] and handles GTK insertions per segment.
fn process_streaming_text(&mut self, buffer: &TextBuffer, text: &str, iter: &mut TextIter) -> String {
let mut result = String::new();
let mut remaining = text;
while !remaining.is_empty() {
if self.in_think_tag {
// We're currently inside a think tag, look for closing tag
if let Some(end_pos) = remaining.find("</think>") {
// Found closing tag - stream the remaining think content
let final_think_content = &remaining[..end_pos];
if !final_think_content.is_empty() {
buffer.insert_with_tags(iter, final_think_content, &[&self.think_tag]);
}
// Close the think section
buffer.insert(iter, "\n\n");
// Reset think state
self.in_think_tag = false;
self.think_buffer.clear();
// Continue with text after closing tag
remaining = &remaining[end_pos + 8..]; // 8 = "</think>".len()
} else {
// No closing tag yet, stream the think content as it arrives
if !remaining.is_empty() {
buffer.insert_with_tags(iter, remaining, &[&self.think_tag]);
}
break; // Wait for more streaming content
}
} else {
// Not in think tag, look for opening tag
if let Some(start_pos) = remaining.find("<think>") {
// Add content before think tag to result for normal processing
result.push_str(&remaining[..start_pos]);
// Start think mode and show the think indicator
self.in_think_tag = true;
self.think_buffer.clear();
buffer.insert(iter, "\n💭 ");
// Continue with content after opening tag
remaining = &remaining[start_pos + 7..]; // 7 = "<think>".len()
} else {
// No think tag found, add all remaining text to result
result.push_str(remaining);
break;
}
for segment in parse_think_segments(text, &mut self.in_think_tag) {
match segment {
StreamSegment::Normal(s) => result.push_str(&s),
StreamSegment::ThinkStart => buffer.insert(iter, "\n💭 "),
StreamSegment::Think(s) => buffer.insert_with_tags(iter, &s, &[&self.think_tag]),
StreamSegment::ThinkEnd => buffer.insert(iter, "\n\n"),
}
}
result
}
@ -386,6 +345,68 @@ impl MarkdownRenderer {
}
}
/// A parsed segment from streaming think-tag processing.
///
/// Separating the pure detection logic (see [`parse_think_segments`]) from GTK mutations lets us
/// unit-test the state machine without a live display session.
#[derive(Debug, PartialEq)] // PartialEq so unit tests can assert on whole segment lists
enum StreamSegment {
    /// Plain text that passes through the markdown renderer.
    Normal(String),
    /// The `<think>` opening boundary was seen; the GTK layer emits an indicator.
    ThinkStart,
    /// Content inside a think block, rendered with the think-tag style.
    Think(String),
    /// The `</think>` closing boundary was seen; the GTK layer emits a separator.
    ThinkEnd,
}
/// Parse `text` into typed [`StreamSegment`]s, updating the in-flight `in_think` cursor.
///
/// Streaming-safe: a think block may open in one call and close in a later call.
/// Parse `text` into typed [`StreamSegment`]s, updating the in-flight `in_think` cursor.
///
/// Streaming-safe: a think block may open in one call and close in a later call,
/// because the caller persists `in_think` between calls.
fn parse_think_segments(text: &str, in_think: &mut bool) -> Vec<StreamSegment> {
    // Deriving skip widths from the literals (instead of hard-coding 7 / 8)
    // keeps the offsets correct if a tag spelling ever changes.
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    let mut segments = Vec::new();
    let mut remaining = text;
    while !remaining.is_empty() {
        if *in_think {
            match remaining.find(CLOSE_TAG) {
                Some(end_pos) => {
                    let content = &remaining[..end_pos];
                    if !content.is_empty() {
                        segments.push(StreamSegment::Think(content.to_string()));
                    }
                    segments.push(StreamSegment::ThinkEnd);
                    *in_think = false;
                    remaining = &remaining[end_pos + CLOSE_TAG.len()..];
                }
                None => {
                    // Tag not closed yet — emit what we have and wait for the next chunk.
                    segments.push(StreamSegment::Think(remaining.to_string()));
                    break;
                }
            }
        } else {
            match remaining.find(OPEN_TAG) {
                Some(start_pos) => {
                    let before = &remaining[..start_pos];
                    if !before.is_empty() {
                        segments.push(StreamSegment::Normal(before.to_string()));
                    }
                    segments.push(StreamSegment::ThinkStart);
                    *in_think = true;
                    remaining = &remaining[start_pos + OPEN_TAG.len()..];
                }
                None => {
                    segments.push(StreamSegment::Normal(remaining.to_string()));
                    break;
                }
            }
        }
    }
    segments
}
/// Helper function to parse color strings (hex format) into RGBA
fn parse_color(color_str: &str) -> Result<gtk4::gdk::RGBA, Box<dyn std::error::Error>> {
let color_str = color_str.trim_start_matches('#');
@ -400,3 +421,130 @@ fn parse_color(color_str: &str) -> Result<gtk4::gdk::RGBA, Box<dyn std::error::E
Ok(gtk4::gdk::RGBA::new(r, g, b, 1.0))
}
#[cfg(test)]
mod tests {
    use super::*;

    // ── parse_color ──────────────────────────────────────────────────────────
    // Channel values are f32 in [0,1]; compare with a small epsilon.

    #[test]
    fn parse_color_red() {
        let c = parse_color("#ff0000").unwrap();
        assert!((c.red() - 1.0).abs() < 1e-4);
        assert!((c.green() - 0.0).abs() < 1e-4);
        assert!((c.blue() - 0.0).abs() < 1e-4);
        assert!((c.alpha() - 1.0).abs() < 1e-4);
    }

    #[test]
    fn parse_color_black() {
        let c = parse_color("#000000").unwrap();
        assert!((c.red() - 0.0).abs() < 1e-4);
        assert!((c.green() - 0.0).abs() < 1e-4);
        assert!((c.blue() - 0.0).abs() < 1e-4);
    }

    #[test]
    fn parse_color_white() {
        let c = parse_color("#ffffff").unwrap();
        assert!((c.red() - 1.0).abs() < 1e-4);
        assert!((c.green() - 1.0).abs() < 1e-4);
        assert!((c.blue() - 1.0).abs() < 1e-4);
    }

    // Only 6-digit hex is accepted; shorthand and junk must be rejected.

    #[test]
    fn parse_color_short_hex_is_error() {
        assert!(parse_color("#fff").is_err());
    }

    #[test]
    fn parse_color_non_hex_chars_is_error() {
        assert!(parse_color("#zzzzzz").is_err());
    }

    #[test]
    fn parse_color_empty_is_error() {
        assert!(parse_color("").is_err());
    }

    // ── parse_think_segments ─────────────────────────────────────────────────

    #[test]
    fn plain_text_produces_single_normal_segment() {
        let mut in_think = false;
        let segs = parse_think_segments("hello world", &mut in_think);
        assert_eq!(segs, vec![StreamSegment::Normal("hello world".into())]);
        assert!(!in_think);
    }

    #[test]
    fn think_block_in_middle_produces_all_segments() {
        let mut in_think = false;
        let segs = parse_think_segments("before <think>thinking</think> after", &mut in_think);
        assert_eq!(segs, vec![
            StreamSegment::Normal("before ".into()),
            StreamSegment::ThinkStart,
            StreamSegment::Think("thinking".into()),
            StreamSegment::ThinkEnd,
            StreamSegment::Normal(" after".into()),
        ]);
        assert!(!in_think);
    }

    /// Streaming case: chunk ends before the closing tag arrives.
    #[test]
    fn unclosed_think_tag_leaves_in_think_true() {
        let mut in_think = false;
        let segs = parse_think_segments("start <think>partial", &mut in_think);
        assert_eq!(segs, vec![
            StreamSegment::Normal("start ".into()),
            StreamSegment::ThinkStart,
            StreamSegment::Think("partial".into()),
        ]);
        assert!(in_think);
    }

    #[test]
    fn closing_tag_in_second_call_closes_correctly() {
        let mut in_think = true; // simulate carrying over from previous call
        let segs = parse_think_segments("rest</think> normal", &mut in_think);
        assert_eq!(segs, vec![
            StreamSegment::Think("rest".into()),
            StreamSegment::ThinkEnd,
            StreamSegment::Normal(" normal".into()),
        ]);
        assert!(!in_think);
    }

    #[test]
    fn continuation_while_in_think_produces_think_segment() {
        let mut in_think = true;
        let segs = parse_think_segments("more thinking...", &mut in_think);
        assert_eq!(segs, vec![StreamSegment::Think("more thinking...".into())]);
        assert!(in_think);
    }

    /// No empty Think segment should be emitted between start and end.
    #[test]
    fn empty_think_block_produces_start_and_end_only() {
        let mut in_think = false;
        let segs = parse_think_segments("<think></think>", &mut in_think);
        assert_eq!(segs, vec![
            StreamSegment::ThinkStart,
            StreamSegment::ThinkEnd,
        ]);
        assert!(!in_think);
    }

    /// No leading empty Normal segment when the tag is the first thing in the text.
    #[test]
    fn think_block_at_very_start() {
        let mut in_think = false;
        let segs = parse_think_segments("<think>reasoning</think>answer", &mut in_think);
        assert_eq!(segs, vec![
            StreamSegment::ThinkStart,
            StreamSegment::Think("reasoning".into()),
            StreamSegment::ThinkEnd,
            StreamSegment::Normal("answer".into()),
        ]);
        assert!(!in_think);
    }
}

View file

@ -6,29 +6,23 @@ use crate::config::Config;
pub type SharedState = Rc<RefCell<AppState>>;
#[derive(Debug)]
/// Application-level errors. Uses `thiserror` so each variant gets a clear, typed
/// message without boilerplate. Callers can match on the variant to handle errors
/// differently (e.g. show a dialog for Config vs. a status-bar message for Api).
#[derive(Debug, thiserror::Error)]
pub enum AppError {
    /// Failure in the API layer (wraps the message as a string).
    #[error("API error: {0}")]
    Api(String),
    /// Failure in the UI layer.
    #[error("UI error: {0}")]
    Ui(String),
    /// Failure manipulating application state.
    #[error("State error: {0}")]
    State(String),
    /// Rejected user input or invalid value.
    #[error("Validation error: {0}")]
    Validation(String),
    /// Failure loading or interpreting configuration.
    #[error("Config error: {0}")]
    Config(String),
}
impl std::fmt::Display for AppError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AppError::Api(msg) => write!(f, "API Error: {}", msg),
AppError::Ui(msg) => write!(f, "UI Error: {}", msg),
AppError::State(msg) => write!(f, "State Error: {}", msg),
AppError::Validation(msg) => write!(f, "Validation Error: {}", msg),
AppError::Config(msg) => write!(f, "Config Error: {}", msg),
}
}
}
impl std::error::Error for AppError {}
pub type AppResult<T> = Result<T, AppError>;
#[derive(Debug, Clone, Copy, PartialEq)]
@ -45,6 +39,9 @@ pub struct AppState {
pub current_task: Option<JoinHandle<()>>,
pub selected_model: Option<String>,
pub status_message: String,
/// System prompt prepended to every request. Initialized from config but can be
/// overridden at runtime (e.g. by a RAG pipeline to inject retrieved context).
pub system_prompt: Option<String>,
pub config: Config,
}
@ -55,6 +52,12 @@ impl Default for AppState {
Config::default()
});
let system_prompt = if config.ollama.system_prompt.is_empty() {
None
} else {
Some(config.ollama.system_prompt.clone())
};
Self {
conversation: Vec::new(),
ollama_url: config.ollama.url.clone(),
@ -63,6 +66,7 @@ impl Default for AppState {
current_task: None,
selected_model: None,
status_message: "Ready".to_string(),
system_prompt,
config,
}
}
@ -105,3 +109,107 @@ impl AppState {
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a fresh AppState with known field values, constructed directly
    /// (bypasses AppState::default(), which loads configuration).
    fn make_state() -> AppState {
        AppState {
            conversation: Vec::new(),
            ollama_url: "http://localhost:11434".into(),
            is_generating: false,
            button_state: ButtonState::Send,
            current_task: None,
            selected_model: None,
            status_message: "Ready".into(),
            system_prompt: None,
            config: Config::default(),
        }
    }

    // set_generating must keep the flag and the button state in lockstep.

    #[test]
    fn set_generating_true_sets_stop_state() {
        let mut state = make_state();
        state.set_generating(true);
        assert!(state.is_generating);
        assert_eq!(state.button_state, ButtonState::Stop);
    }

    #[test]
    fn set_generating_false_sets_send_state() {
        let mut state = make_state();
        state.is_generating = true;
        state.button_state = ButtonState::Stop;
        state.set_generating(false);
        assert!(!state.is_generating);
        assert_eq!(state.button_state, ButtonState::Send);
    }

    #[test]
    fn add_user_message_appends_with_correct_role() {
        let mut state = make_state();
        state.add_user_message("hello".into());
        assert_eq!(state.conversation.len(), 1);
        assert_eq!(state.conversation[0].role, "user");
        assert_eq!(state.conversation[0].content, "hello");
    }

    #[test]
    fn add_assistant_message_appends_with_correct_role() {
        let mut state = make_state();
        state.add_assistant_message("hi there".into());
        assert_eq!(state.conversation.len(), 1);
        assert_eq!(state.conversation[0].role, "assistant");
        assert_eq!(state.conversation[0].content, "hi there");
    }

    #[test]
    fn conversation_preserves_insertion_order() {
        let mut state = make_state();
        state.add_user_message("first".into());
        state.add_assistant_message("second".into());
        state.add_user_message("third".into());
        assert_eq!(state.conversation.len(), 3);
        assert_eq!(state.conversation[0].role, "user");
        assert_eq!(state.conversation[1].role, "assistant");
        assert_eq!(state.conversation[2].role, "user");
    }

    #[test]
    fn set_status_updates_message() {
        let mut state = make_state();
        state.set_status("Loading models...".into());
        assert_eq!(state.status_message, "Loading models...");
    }

    /// Aborting with no task in flight must still reset the generating state.
    #[test]
    fn abort_current_task_without_task_resets_state() {
        let mut state = make_state();
        state.is_generating = true;
        state.button_state = ButtonState::Stop;
        state.abort_current_task();
        assert!(!state.is_generating);
        assert_eq!(state.button_state, ButtonState::Send);
        assert_eq!(state.status_message, "Generation stopped");
        assert!(state.current_task.is_none());
    }

    #[tokio::test]
    async fn abort_current_task_aborts_running_task() {
        let mut state = make_state();
        // Spawn a task that sleeps forever so we can verify it gets aborted
        let handle = tokio::spawn(async {
            tokio::time::sleep(tokio::time::Duration::from_secs(3600)).await;
        });
        state.current_task = Some(handle);
        state.is_generating = true;
        state.button_state = ButtonState::Stop;
        state.abort_current_task();
        assert!(state.current_task.is_none());
        assert!(!state.is_generating);
        assert_eq!(state.status_message, "Generation stopped");
    }
}

View file

@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
// --- Types ---
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
@ -37,3 +39,46 @@ pub struct ModelInfo {
pub struct ModelsResponse {
pub models: Vec<ModelInfo>,
}
#[cfg(test)]
mod tests {
    use super::*;

    // `cargo test` will run these. This is the standard Rust pattern:
    // put tests in a `#[cfg(test)]` module so they're compiled only when testing.

    #[test]
    fn chat_message_serializes_and_deserializes() {
        // Verify that serde roundtrips correctly — if you ever rename a field
        // or change its type, this test catches it before it reaches the API.
        let msg = ChatMessage {
            role: "user".to_string(),
            content: "Hello, world!".to_string(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let decoded: ChatMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.role, msg.role);
        assert_eq!(decoded.content, msg.content);
    }

    #[test]
    fn chat_request_includes_stream_flag() {
        let req = ChatRequest {
            model: "llama3".to_string(),
            messages: vec![],
            stream: true,
        };
        let json = serde_json::to_string(&req).unwrap();
        // The Ollama API requires `"stream": true` in the body.
        assert!(json.contains("\"stream\":true"));
    }

    #[test]
    fn stream_response_parses_done_flag() {
        // Real payload shape from the Ollama streaming API.
        let raw = r#"{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hi"},"done":true}"#;
        let resp: StreamResponse = serde_json::from_str(raw).unwrap();
        assert!(resp.done);
        assert_eq!(resp.message.content, "Hi");
    }
}

View file

@ -76,15 +76,16 @@ impl ChatView {
}
// Add sender label with bold formatting
let sender_tag = gtk4::TextTag::new(Some("sender"));
sender_tag.set_weight(700);
sender_tag.set_property("pixels-below-lines", 4);
// Add the sender tag to the buffer's tag table if it's not already there
let tag_table = self.text_buffer.tag_table();
if tag_table.lookup("sender").is_none() {
tag_table.add(&sender_tag);
}
let sender_tag = if let Some(existing) = tag_table.lookup("sender") {
existing
} else {
let tag = gtk4::TextTag::new(Some("sender"));
tag.set_weight(700);
tag.set_property("pixels-below-lines", 4);
tag_table.add(&tag);
tag
};
self.text_buffer.insert_with_tags(&mut end_iter, &format!("{}:\n", sender), &[&sender_tag]);
end_iter = self.text_buffer.end_iter();
@ -133,9 +134,6 @@ impl ChatView {
// Delete existing content from mark to end
self.text_buffer.delete(&mut start_iter, &mut end_iter);
// Get a fresh iterator at the mark position after deletion
let _insert_iter = self.text_buffer.iter_at_mark(mark);
// Render markdown directly to the main buffer
// We use a separate method to avoid conflicts with the borrow checker
self.render_markdown_at_mark(mark, accumulated_content, config);

View file

@ -116,23 +116,12 @@ fn set_generating_state(
shared_state: &SharedState,
controls: &ControlsArea,
button: &gtk4::Button,
generating: bool
generating: bool,
) {
{
let mut state = shared_state.borrow_mut();
state.set_generating(generating);
state.set_status(if generating {
"Assistant is typing...".to_string()
} else {
"Ready".to_string()
});
}
let status = if generating { "Assistant is typing..." } else { "Ready" };
shared_state.borrow_mut().set_generating(generating);
update_button_state(shared_state, button);
controls.set_status(if generating {
"Assistant is typing..."
} else {
"Ready"
});
controls.set_status(status);
}
fn update_button_state(shared_state: &SharedState, button: &gtk4::Button) {
@ -177,12 +166,23 @@ fn start_streaming_task(
model: String,
) {
let (content_sender, content_receiver) = async_channel::bounded::<String>(100);
let (result_sender, result_receiver) = async_channel::bounded(1);
let (result_sender, result_receiver) = async_channel::bounded::<Result<String, crate::api::ApiError>>(1);
// Extract data from shared state for API call
let (conversation, ollama_url) = {
// Extract data from shared state for API call.
// Only send the most recent `max_context_messages` turns to stay within the model's
// context window. Prepend the system prompt (if set) as the first message.
let (messages, ollama_url, batch_size, batch_timeout_ms) = {
let state = shared_state.borrow();
(state.conversation.clone(), state.ollama_url.clone())
let max = state.config.ollama.max_context_messages;
let skip = state.conversation.len().saturating_sub(max);
let mut msgs: Vec<_> = state.conversation[skip..].to_vec();
if let Some(ref prompt) = state.system_prompt {
msgs.insert(0, crate::types::ChatMessage {
role: "system".to_string(),
content: prompt.clone(),
});
}
(msgs, state.ollama_url.clone(), state.config.streaming.batch_size, state.config.streaming.batch_timeout_ms)
};
// Spawn API task
@ -190,8 +190,10 @@ fn start_streaming_task(
let result = api::send_chat_request_streaming(
&ollama_url,
&model,
&std::sync::Arc::new(std::sync::Mutex::new(conversation)),
messages,
content_sender,
batch_size,
batch_timeout_ms,
).await;
let _ = result_sender.send(result).await;
});
@ -216,7 +218,7 @@ fn setup_streaming_handlers(
controls: &ControlsArea,
button: &gtk4::Button,
content_receiver: async_channel::Receiver<String>,
result_receiver: async_channel::Receiver<Result<(String, Option<String>), Box<dyn std::error::Error + Send + Sync>>>,
result_receiver: async_channel::Receiver<Result<String, crate::api::ApiError>>,
) {
// Setup UI structure for streaming
let mut end_iter = chat_view.buffer().end_iter();
@ -241,7 +243,6 @@ fn setup_streaming_handlers(
accumulated_content.push_str(&content_batch);
let config = shared_state_streaming.borrow().config.clone();
chat_view_content.update_streaming_markdown(&response_mark_clone, &accumulated_content, &config);
chat_view_content.scroll_to_bottom();
}
});
@ -258,10 +259,10 @@ fn setup_streaming_handlers(
Ok(response_text) => {
// Apply final markdown formatting
let config = shared_state_final.borrow().config.clone();
chat_view_final.insert_formatted_at_mark(&response_mark, &response_text.0, &config);
chat_view_final.insert_formatted_at_mark(&response_mark, &response_text, &config);
// Update conversation state
shared_state_final.borrow_mut().add_assistant_message(response_text.0);
shared_state_final.borrow_mut().add_assistant_message(response_text);
set_generating_state(&shared_state_final, &controls_final, &button_final, false);
}
Err(e) => {