fix start_streaming_task conversation Vec, remove dead code, plumbing for extensions
This commit is contained in:
parent
dc65c3274f
commit
2a75c9064f
8 changed files with 123 additions and 78 deletions
|
|
@ -15,3 +15,4 @@ dirs = "6.0"
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
async-channel = "2.3"
|
async-channel = "2.3"
|
||||||
pulldown-cmark = { version = "0.13.0", default-features = false, features = ["html"] }
|
pulldown-cmark = { version = "0.13.0", default-features = false, features = ["html"] }
|
||||||
|
thiserror = "2"
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,10 @@ stop_button = "#dc3545"
|
||||||
[ollama]
|
[ollama]
|
||||||
url = "http://localhost:11434"
|
url = "http://localhost:11434"
|
||||||
timeout_seconds = 120
|
timeout_seconds = 120
|
||||||
|
|
||||||
|
[streaming]
|
||||||
|
batch_size = 20
|
||||||
|
batch_timeout_ms = 100
|
||||||
```
|
```
|
||||||
|
|
||||||
## Building
|
## Building
|
||||||
|
|
|
||||||
47
src/api.rs
47
src/api.rs
|
|
@ -1,13 +1,29 @@
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use tokio::time::{timeout, Duration};
|
use tokio::time::{timeout, Duration};
|
||||||
use crate::types::{ChatMessage, ChatRequest, ModelInfo, ModelsResponse, StreamResponse};
|
use crate::types::{ChatMessage, ChatRequest, ModelInfo, ModelsResponse, StreamResponse};
|
||||||
|
|
||||||
pub async fn fetch_models(base_url: &str) -> Result<Vec<ModelInfo>, Box<dyn std::error::Error + Send + Sync>> {
|
/// Typed errors for the Ollama API layer. Using `thiserror` means callers can match
|
||||||
|
/// on exactly what went wrong instead of downcasting a `Box<dyn Error>`.
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ApiError {
|
||||||
|
#[error("HTTP request failed: {0}")]
|
||||||
|
Http(#[from] reqwest::Error),
|
||||||
|
#[error("Request timed out")]
|
||||||
|
Timeout,
|
||||||
|
#[error("Server returned error status {0}")]
|
||||||
|
BadStatus(u16),
|
||||||
|
#[error("Failed to parse response: {0}")]
|
||||||
|
Parse(#[from] serde_json::Error),
|
||||||
|
#[error("Model returned empty response")]
|
||||||
|
EmptyResponse,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn fetch_models(base_url: &str) -> Result<Vec<ModelInfo>, ApiError> {
|
||||||
let url = format!("{}/api/tags", base_url);
|
let url = format!("{}/api/tags", base_url);
|
||||||
|
|
||||||
// Add timeout to prevent hanging
|
let response = timeout(Duration::from_secs(10), reqwest::get(&url))
|
||||||
let response = timeout(Duration::from_secs(10), reqwest::get(&url)).await??;
|
.await
|
||||||
|
.map_err(|_| ApiError::Timeout)??;
|
||||||
let models_response: ModelsResponse = response.json().await?;
|
let models_response: ModelsResponse = response.json().await?;
|
||||||
Ok(models_response.models)
|
Ok(models_response.models)
|
||||||
}
|
}
|
||||||
|
|
@ -15,13 +31,11 @@ pub async fn fetch_models(base_url: &str) -> Result<Vec<ModelInfo>, Box<dyn std:
|
||||||
pub async fn send_chat_request_streaming(
|
pub async fn send_chat_request_streaming(
|
||||||
base_url: &str,
|
base_url: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
conversation: &Arc<Mutex<Vec<ChatMessage>>>,
|
messages: Vec<ChatMessage>,
|
||||||
token_sender: async_channel::Sender<String>,
|
token_sender: async_channel::Sender<String>,
|
||||||
) -> Result<(String, Option<String>), Box<dyn std::error::Error + Send + Sync>> {
|
batch_size: usize,
|
||||||
let messages = {
|
batch_timeout_ms: u64,
|
||||||
let conversation = conversation.lock().unwrap();
|
) -> Result<String, ApiError> {
|
||||||
conversation.iter().cloned().collect::<Vec<_>>()
|
|
||||||
};
|
|
||||||
|
|
||||||
let request = ChatRequest {
|
let request = ChatRequest {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
|
|
@ -42,15 +56,14 @@ pub async fn send_chat_request_streaming(
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
return Err(format!("API request failed with status: {}", response.status()).into());
|
return Err(ApiError::BadStatus(response.status().as_u16()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
let mut full_response = String::new();
|
let mut full_response = String::new();
|
||||||
let mut current_batch = String::new();
|
let mut current_batch = String::new();
|
||||||
let mut tokens_since_last_send = 0;
|
let mut tokens_since_last_send = 0;
|
||||||
const BATCH_SIZE: usize = 20;
|
let batch_timeout = Duration::from_millis(batch_timeout_ms);
|
||||||
const BATCH_TIMEOUT: Duration = Duration::from_millis(100);
|
|
||||||
|
|
||||||
let mut last_send = tokio::time::Instant::now();
|
let mut last_send = tokio::time::Instant::now();
|
||||||
|
|
||||||
|
|
@ -74,8 +87,8 @@ pub async fn send_chat_request_streaming(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send batch if conditions are met
|
// Send batch if conditions are met
|
||||||
let should_send = tokens_since_last_send >= BATCH_SIZE
|
let should_send = tokens_since_last_send >= batch_size
|
||||||
|| last_send.elapsed() >= BATCH_TIMEOUT
|
|| last_send.elapsed() >= batch_timeout
|
||||||
|| stream_response.done;
|
|| stream_response.done;
|
||||||
|
|
||||||
if should_send {
|
if should_send {
|
||||||
|
|
@ -115,8 +128,8 @@ pub async fn send_chat_request_streaming(
|
||||||
drop(token_sender);
|
drop(token_sender);
|
||||||
|
|
||||||
if full_response.is_empty() {
|
if full_response.is_empty() {
|
||||||
return Err("No response received from the model".into());
|
return Err(ApiError::EmptyResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok((full_response, None))
|
Ok(full_response)
|
||||||
}
|
}
|
||||||
|
|
@ -7,6 +7,7 @@ pub struct Config {
|
||||||
pub ui: UiConfig,
|
pub ui: UiConfig,
|
||||||
pub colors: ColorConfig,
|
pub colors: ColorConfig,
|
||||||
pub ollama: OllamaConfig,
|
pub ollama: OllamaConfig,
|
||||||
|
pub streaming: StreamingConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -39,6 +40,20 @@ pub struct ColorConfig {
|
||||||
pub struct OllamaConfig {
|
pub struct OllamaConfig {
|
||||||
pub url: String,
|
pub url: String,
|
||||||
pub timeout_seconds: u64,
|
pub timeout_seconds: u64,
|
||||||
|
/// Maximum number of conversation turns sent to the model (most recent N messages).
|
||||||
|
/// Keeps context within the model's limit. Set higher for longer memory.
|
||||||
|
pub max_context_messages: usize,
|
||||||
|
/// Optional system prompt prepended to every conversation.
|
||||||
|
/// Leave empty ("") to disable. RAG can override this at runtime via AppState.
|
||||||
|
pub system_prompt: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct StreamingConfig {
|
||||||
|
/// Number of tokens to accumulate before flushing to the UI.
|
||||||
|
pub batch_size: usize,
|
||||||
|
/// Maximum milliseconds to wait before flushing a partial batch.
|
||||||
|
pub batch_timeout_ms: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Config {
|
impl Default for Config {
|
||||||
|
|
@ -47,6 +62,7 @@ impl Default for Config {
|
||||||
ui: UiConfig::default(),
|
ui: UiConfig::default(),
|
||||||
colors: ColorConfig::default(),
|
colors: ColorConfig::default(),
|
||||||
ollama: OllamaConfig::default(),
|
ollama: OllamaConfig::default(),
|
||||||
|
streaming: StreamingConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -88,6 +104,17 @@ impl Default for OllamaConfig {
|
||||||
Self {
|
Self {
|
||||||
url: "http://localhost:11434".to_string(),
|
url: "http://localhost:11434".to_string(),
|
||||||
timeout_seconds: 120,
|
timeout_seconds: 120,
|
||||||
|
max_context_messages: 20,
|
||||||
|
system_prompt: String::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for StreamingConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
batch_size: 20,
|
||||||
|
batch_timeout_ms: 100,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,6 @@ pub struct MarkdownRenderer {
|
||||||
tags_setup: bool,
|
tags_setup: bool,
|
||||||
// State for streaming think tag processing
|
// State for streaming think tag processing
|
||||||
in_think_tag: bool,
|
in_think_tag: bool,
|
||||||
think_buffer: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MarkdownRenderer {
|
impl MarkdownRenderer {
|
||||||
|
|
@ -47,7 +46,6 @@ impl MarkdownRenderer {
|
||||||
format_stack: Vec::new(),
|
format_stack: Vec::new(),
|
||||||
tags_setup: false,
|
tags_setup: false,
|
||||||
in_think_tag: false,
|
in_think_tag: false,
|
||||||
think_buffer: String::new(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -203,7 +201,6 @@ impl MarkdownRenderer {
|
||||||
|
|
||||||
// Reset think state
|
// Reset think state
|
||||||
self.in_think_tag = false;
|
self.in_think_tag = false;
|
||||||
self.think_buffer.clear();
|
|
||||||
|
|
||||||
// Continue with text after closing tag
|
// Continue with text after closing tag
|
||||||
remaining = &remaining[end_pos + 8..]; // 8 = "</think>".len()
|
remaining = &remaining[end_pos + 8..]; // 8 = "</think>".len()
|
||||||
|
|
@ -222,7 +219,6 @@ impl MarkdownRenderer {
|
||||||
|
|
||||||
// Start think mode and show the think indicator
|
// Start think mode and show the think indicator
|
||||||
self.in_think_tag = true;
|
self.in_think_tag = true;
|
||||||
self.think_buffer.clear();
|
|
||||||
buffer.insert(iter, "\n💭 ");
|
buffer.insert(iter, "\n💭 ");
|
||||||
|
|
||||||
// Continue with content after opening tag
|
// Continue with content after opening tag
|
||||||
|
|
|
||||||
34
src/state.rs
34
src/state.rs
|
|
@ -6,29 +6,23 @@ use crate::config::Config;
|
||||||
|
|
||||||
pub type SharedState = Rc<RefCell<AppState>>;
|
pub type SharedState = Rc<RefCell<AppState>>;
|
||||||
|
|
||||||
#[derive(Debug)]
|
/// Application-level errors. Uses `thiserror` so each variant gets a clear, typed
|
||||||
|
/// message without boilerplate. Callers can match on the variant to handle errors
|
||||||
|
/// differently (e.g. show a dialog for Config vs. a status-bar message for Api).
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum AppError {
|
pub enum AppError {
|
||||||
|
#[error("API error: {0}")]
|
||||||
Api(String),
|
Api(String),
|
||||||
|
#[error("UI error: {0}")]
|
||||||
Ui(String),
|
Ui(String),
|
||||||
|
#[error("State error: {0}")]
|
||||||
State(String),
|
State(String),
|
||||||
|
#[error("Validation error: {0}")]
|
||||||
Validation(String),
|
Validation(String),
|
||||||
|
#[error("Config error: {0}")]
|
||||||
Config(String),
|
Config(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for AppError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
AppError::Api(msg) => write!(f, "API Error: {}", msg),
|
|
||||||
AppError::Ui(msg) => write!(f, "UI Error: {}", msg),
|
|
||||||
AppError::State(msg) => write!(f, "State Error: {}", msg),
|
|
||||||
AppError::Validation(msg) => write!(f, "Validation Error: {}", msg),
|
|
||||||
AppError::Config(msg) => write!(f, "Config Error: {}", msg),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for AppError {}
|
|
||||||
|
|
||||||
pub type AppResult<T> = Result<T, AppError>;
|
pub type AppResult<T> = Result<T, AppError>;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
|
|
@ -45,6 +39,9 @@ pub struct AppState {
|
||||||
pub current_task: Option<JoinHandle<()>>,
|
pub current_task: Option<JoinHandle<()>>,
|
||||||
pub selected_model: Option<String>,
|
pub selected_model: Option<String>,
|
||||||
pub status_message: String,
|
pub status_message: String,
|
||||||
|
/// System prompt prepended to every request. Initialized from config but can be
|
||||||
|
/// overridden at runtime (e.g. by a RAG pipeline to inject retrieved context).
|
||||||
|
pub system_prompt: Option<String>,
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -55,6 +52,12 @@ impl Default for AppState {
|
||||||
Config::default()
|
Config::default()
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let system_prompt = if config.ollama.system_prompt.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(config.ollama.system_prompt.clone())
|
||||||
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
conversation: Vec::new(),
|
conversation: Vec::new(),
|
||||||
ollama_url: config.ollama.url.clone(),
|
ollama_url: config.ollama.url.clone(),
|
||||||
|
|
@ -63,6 +66,7 @@ impl Default for AppState {
|
||||||
current_task: None,
|
current_task: None,
|
||||||
selected_model: None,
|
selected_model: None,
|
||||||
status_message: "Ready".to_string(),
|
status_message: "Ready".to_string(),
|
||||||
|
system_prompt,
|
||||||
config,
|
config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -76,15 +76,16 @@ impl ChatView {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add sender label with bold formatting
|
// Add sender label with bold formatting
|
||||||
let sender_tag = gtk4::TextTag::new(Some("sender"));
|
|
||||||
sender_tag.set_weight(700);
|
|
||||||
sender_tag.set_property("pixels-below-lines", 4);
|
|
||||||
|
|
||||||
// Add the sender tag to the buffer's tag table if it's not already there
|
|
||||||
let tag_table = self.text_buffer.tag_table();
|
let tag_table = self.text_buffer.tag_table();
|
||||||
if tag_table.lookup("sender").is_none() {
|
let sender_tag = if let Some(existing) = tag_table.lookup("sender") {
|
||||||
tag_table.add(&sender_tag);
|
existing
|
||||||
}
|
} else {
|
||||||
|
let tag = gtk4::TextTag::new(Some("sender"));
|
||||||
|
tag.set_weight(700);
|
||||||
|
tag.set_property("pixels-below-lines", 4);
|
||||||
|
tag_table.add(&tag);
|
||||||
|
tag
|
||||||
|
};
|
||||||
|
|
||||||
self.text_buffer.insert_with_tags(&mut end_iter, &format!("{}:\n", sender), &[&sender_tag]);
|
self.text_buffer.insert_with_tags(&mut end_iter, &format!("{}:\n", sender), &[&sender_tag]);
|
||||||
end_iter = self.text_buffer.end_iter();
|
end_iter = self.text_buffer.end_iter();
|
||||||
|
|
@ -133,9 +134,6 @@ impl ChatView {
|
||||||
// Delete existing content from mark to end
|
// Delete existing content from mark to end
|
||||||
self.text_buffer.delete(&mut start_iter, &mut end_iter);
|
self.text_buffer.delete(&mut start_iter, &mut end_iter);
|
||||||
|
|
||||||
// Get a fresh iterator at the mark position after deletion
|
|
||||||
let _insert_iter = self.text_buffer.iter_at_mark(mark);
|
|
||||||
|
|
||||||
// Render markdown directly to the main buffer
|
// Render markdown directly to the main buffer
|
||||||
// We use a separate method to avoid conflicts with the borrow checker
|
// We use a separate method to avoid conflicts with the borrow checker
|
||||||
self.render_markdown_at_mark(mark, accumulated_content, config);
|
self.render_markdown_at_mark(mark, accumulated_content, config);
|
||||||
|
|
|
||||||
|
|
@ -116,23 +116,12 @@ fn set_generating_state(
|
||||||
shared_state: &SharedState,
|
shared_state: &SharedState,
|
||||||
controls: &ControlsArea,
|
controls: &ControlsArea,
|
||||||
button: >k4::Button,
|
button: >k4::Button,
|
||||||
generating: bool
|
generating: bool,
|
||||||
) {
|
) {
|
||||||
{
|
let status = if generating { "Assistant is typing..." } else { "Ready" };
|
||||||
let mut state = shared_state.borrow_mut();
|
shared_state.borrow_mut().set_generating(generating);
|
||||||
state.set_generating(generating);
|
|
||||||
state.set_status(if generating {
|
|
||||||
"Assistant is typing...".to_string()
|
|
||||||
} else {
|
|
||||||
"Ready".to_string()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
update_button_state(shared_state, button);
|
update_button_state(shared_state, button);
|
||||||
controls.set_status(if generating {
|
controls.set_status(status);
|
||||||
"Assistant is typing..."
|
|
||||||
} else {
|
|
||||||
"Ready"
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_button_state(shared_state: &SharedState, button: >k4::Button) {
|
fn update_button_state(shared_state: &SharedState, button: >k4::Button) {
|
||||||
|
|
@ -177,12 +166,23 @@ fn start_streaming_task(
|
||||||
model: String,
|
model: String,
|
||||||
) {
|
) {
|
||||||
let (content_sender, content_receiver) = async_channel::bounded::<String>(100);
|
let (content_sender, content_receiver) = async_channel::bounded::<String>(100);
|
||||||
let (result_sender, result_receiver) = async_channel::bounded(1);
|
let (result_sender, result_receiver) = async_channel::bounded::<Result<String, crate::api::ApiError>>(1);
|
||||||
|
|
||||||
// Extract data from shared state for API call
|
// Extract data from shared state for API call.
|
||||||
let (conversation, ollama_url) = {
|
// Only send the most recent `max_context_messages` turns to stay within the model's
|
||||||
|
// context window. Prepend the system prompt (if set) as the first message.
|
||||||
|
let (messages, ollama_url, batch_size, batch_timeout_ms) = {
|
||||||
let state = shared_state.borrow();
|
let state = shared_state.borrow();
|
||||||
(state.conversation.clone(), state.ollama_url.clone())
|
let max = state.config.ollama.max_context_messages;
|
||||||
|
let skip = state.conversation.len().saturating_sub(max);
|
||||||
|
let mut msgs: Vec<_> = state.conversation[skip..].to_vec();
|
||||||
|
if let Some(ref prompt) = state.system_prompt {
|
||||||
|
msgs.insert(0, crate::types::ChatMessage {
|
||||||
|
role: "system".to_string(),
|
||||||
|
content: prompt.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
(msgs, state.ollama_url.clone(), state.config.streaming.batch_size, state.config.streaming.batch_timeout_ms)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Spawn API task
|
// Spawn API task
|
||||||
|
|
@ -190,8 +190,10 @@ fn start_streaming_task(
|
||||||
let result = api::send_chat_request_streaming(
|
let result = api::send_chat_request_streaming(
|
||||||
&ollama_url,
|
&ollama_url,
|
||||||
&model,
|
&model,
|
||||||
&std::sync::Arc::new(std::sync::Mutex::new(conversation)),
|
messages,
|
||||||
content_sender,
|
content_sender,
|
||||||
|
batch_size,
|
||||||
|
batch_timeout_ms,
|
||||||
).await;
|
).await;
|
||||||
let _ = result_sender.send(result).await;
|
let _ = result_sender.send(result).await;
|
||||||
});
|
});
|
||||||
|
|
@ -216,7 +218,7 @@ fn setup_streaming_handlers(
|
||||||
controls: &ControlsArea,
|
controls: &ControlsArea,
|
||||||
button: >k4::Button,
|
button: >k4::Button,
|
||||||
content_receiver: async_channel::Receiver<String>,
|
content_receiver: async_channel::Receiver<String>,
|
||||||
result_receiver: async_channel::Receiver<Result<(String, Option<String>), Box<dyn std::error::Error + Send + Sync>>>,
|
result_receiver: async_channel::Receiver<Result<String, crate::api::ApiError>>,
|
||||||
) {
|
) {
|
||||||
// Setup UI structure for streaming
|
// Setup UI structure for streaming
|
||||||
let mut end_iter = chat_view.buffer().end_iter();
|
let mut end_iter = chat_view.buffer().end_iter();
|
||||||
|
|
@ -258,10 +260,10 @@ fn setup_streaming_handlers(
|
||||||
Ok(response_text) => {
|
Ok(response_text) => {
|
||||||
// Apply final markdown formatting
|
// Apply final markdown formatting
|
||||||
let config = shared_state_final.borrow().config.clone();
|
let config = shared_state_final.borrow().config.clone();
|
||||||
chat_view_final.insert_formatted_at_mark(&response_mark, &response_text.0, &config);
|
chat_view_final.insert_formatted_at_mark(&response_mark, &response_text, &config);
|
||||||
|
|
||||||
// Update conversation state
|
// Update conversation state
|
||||||
shared_state_final.borrow_mut().add_assistant_message(response_text.0);
|
shared_state_final.borrow_mut().add_assistant_message(response_text);
|
||||||
set_generating_state(&shared_state_final, &controls_final, &button_final, false);
|
set_generating_state(&shared_state_final, &controls_final, &button_final, false);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue