Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions crates/goose/src/security/classification_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type ClassificationResponse = Vec<Vec<ClassificationLabel>>;
#[derive(Debug, Deserialize, Clone)]
pub struct ModelEndpointInfo {
pub endpoint: String,
pub model_type: Option<String>,
#[serde(flatten)]
pub extra_params: HashMap<String, serde_json::Value>,
}
Expand Down Expand Up @@ -75,7 +76,7 @@ impl ClassificationClient {
model_name
))?;

tracing::info!(
tracing::debug!(
model_name = %model_name,
endpoint = %model_info.endpoint,
extra_params = ?model_info.extra_params,
Expand All @@ -90,6 +91,30 @@ impl ClassificationClient {
)
}

pub fn from_model_type(model_type: &str, timeout_ms: Option<u64>) -> Result<Self> {
let mapping = serde_json::from_str::<ModelMappingConfig>(
&std::env::var("SECURITY_ML_MODEL_MAPPING")
.context("SECURITY_ML_MODEL_MAPPING environment variable not set")?,
)
.context("Failed to parse SECURITY_ML_MODEL_MAPPING JSON")?;

let (_, model_info) = mapping
.models
.iter()
.find(|(_, info)| info.model_type.as_deref() == Some(model_type))
.context(format!(
"No model with type '{}' found in SECURITY_ML_MODEL_MAPPING",
model_type
))?;

Self::new(
model_info.endpoint.clone(),
timeout_ms,
None,
Some(model_info.extra_params.clone()),
)
}
Comment on lines +94 to +116
Copy link

Copilot AI Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new from_model_type method lacks test coverage. Consider adding tests for: 1) successfully finding a model by type, 2) handling cases where no model with the specified type exists, 3) handling invalid JSON in SECURITY_ML_MODEL_MAPPING.

Copilot uses AI. Check for mistakes.

pub fn from_endpoint(
endpoint_url: String,
timeout_ms: Option<u64>,
Expand All @@ -104,7 +129,7 @@ impl ClassificationClient {
.map(|t| t.trim().to_string())
.filter(|t| !t.is_empty());

tracing::info!(
tracing::debug!(
endpoint = %endpoint_url,
has_token = auth_token.is_some(),
"Creating classification client from endpoint"
Expand All @@ -114,12 +139,6 @@ impl ClassificationClient {
}

pub async fn classify(&self, text: &str) -> Result<f32> {
tracing::debug!(
endpoint = %self.endpoint_url,
text_length = text.len(),
"Sending classification request"
);

let parameters = self
.extra_params
.as_ref()
Expand Down Expand Up @@ -197,14 +216,6 @@ impl ClassificationClient {
}
};

tracing::info!(
injection_score = %injection_score,
top_label = %top_label.label,
top_score = %top_label.score,
normalized = !is_probabilities,
"Classification complete"
);

Ok(injection_score)
}

Expand Down
10 changes: 5 additions & 5 deletions crates/goose/src/security/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ impl SecurityManager {
Ok(s) => {
tracing::info!(
counter.goose.prompt_injection_scanner_enabled = 1,
"🔓 Security scanner initialized with ML-based detection"
"Security scanner initialized with ML-based detection"
);
s
}
Err(e) => {
let error_chain = format!("{:#}", e);
tracing::warn!(
"⚠️ ML scanning requested but failed to initialize. Falling back to pattern-only scanning.\n\nError details:\n{}",
"ML scanning requested but failed to initialize. Falling back to pattern-only scanning.\n\nError details:\n{}",
error_chain
);
PromptInjectionScanner::new()
Expand All @@ -85,7 +85,7 @@ impl SecurityManager {
} else {
tracing::info!(
counter.goose.prompt_injection_scanner_enabled = 1,
"🔓 Security scanner initialized with pattern-based detection only"
"Security scanner initialized with pattern-based detection only"
);
PromptInjectionScanner::new()
};
Expand All @@ -95,8 +95,8 @@ impl SecurityManager {

let mut results = Vec::new();

tracing::info!(
"🔍 Starting security analysis - {} tool requests, {} messages",
tracing::debug!(
"Starting security analysis - {} tool requests, {} messages",
tool_requests.len(),
messages.len()
);
Expand Down
Loading
Loading