diff --git a/README.md b/README.md index def9c1e..c53f99b 100644 --- a/README.md +++ b/README.md @@ -175,6 +175,7 @@ All request information is available via the `r` object: - `r.host` - Hostname from the URL - `r.scheme` - URL scheme (http or https) - `r.path` - Path portion of the URL +- `r.requester_ip` - IP address of the client making the request - `r.block_message` - Optional message to set when denying (writable) **JavaScript evaluation rules:** @@ -224,6 +225,7 @@ If `--sh` has spaces, it's run through `sh`; otherwise it's executed directly. - `HTTPJAIL_HOST` - Hostname from the URL - `HTTPJAIL_SCHEME` - URL scheme (http or https) - `HTTPJAIL_PATH` - Path component of the URL +- `HTTPJAIL_REQUESTER_IP` - IP address of the client making the request **Script requirements:** @@ -236,7 +238,6 @@ If `--sh` has spaces, it's run through `sh`; otherwise it's executed directly. > Script-based evaluation can also be used for custom logging! Your script can log requests to a database, send metrics to a monitoring service, or implement complex audit trails before returning the allow/deny decision. ## Advanced Options - ```bash # Verbose logging httpjail -vvv --js "true" -- curl https://example.com diff --git a/src/proxy.rs b/src/proxy.rs index aa993c5..895688c 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -293,7 +293,8 @@ impl ProxyServer { tokio::spawn(async move { if let Err(e) = - handle_http_connection(stream, rule_engine, cert_manager).await + handle_http_connection(stream, rule_engine, cert_manager, addr) + .await { error!("Error handling HTTP connection: {:?}", e); } @@ -335,7 +336,8 @@ impl ProxyServer { tokio::spawn(async move { if let Err(e) = - handle_https_connection(stream, rule_engine, cert_manager).await + handle_https_connection(stream, rule_engine, cert_manager, addr) + .await { error!("Error handling HTTPS connection: {:?}", e); } @@ -364,10 +366,16 @@ async fn handle_http_connection( stream: TcpStream, rule_engine: Arc, cert_manager: Arc, + remote_addr: SocketAddr, ) -> Result<()> { let io = TokioIo::new(stream); let service = service_fn(move |req| { - handle_http_request(req, Arc::clone(&rule_engine), Arc::clone(&cert_manager)) + handle_http_request( + req, + Arc::clone(&rule_engine), + Arc::clone(&cert_manager), + remote_addr, + ) }); http1::Builder::new() @@ -383,15 +391,17 @@ async fn handle_https_connection( stream: TcpStream, rule_engine: Arc, cert_manager: Arc, + remote_addr: SocketAddr, ) -> Result<()> { // Delegate to the TLS-specific module - crate::proxy_tls::handle_https_connection(stream, rule_engine, cert_manager).await + crate::proxy_tls::handle_https_connection(stream, rule_engine, cert_manager, remote_addr).await } pub async fn handle_http_request( req: Request, rule_engine: Arc, _cert_manager: Arc, + remote_addr: SocketAddr, ) -> Result>, std::convert::Infallible> { let method = req.method().clone(); let uri = req.uri().clone(); @@ -412,10 +422,16 @@ pub async fn handle_http_request( format!("http://{}{}", host, path) }; - debug!("Proxying HTTP request: {} {}", method, full_url); + debug!( + "Proxying HTTP request: {} {} from {}", + method, full_url, remote_addr + ); - // Evaluate rules with method - let evaluation = rule_engine.evaluate_with_context(method, &full_url).await; + // Evaluate rules with method and requester IP + let requester_ip = remote_addr.ip().to_string(); + let evaluation = rule_engine + .evaluate_with_context_and_ip(method, &full_url, &requester_ip) + .await; match evaluation.action { Action::Allow => { debug!("Request allowed: {}", full_url); diff --git a/src/proxy_tls.rs b/src/proxy_tls.rs index b58fe0e..f997d00 100644 --- a/src/proxy_tls.rs +++ b/src/proxy_tls.rs @@ -37,8 +37,9 @@ pub async fn handle_https_connection( stream: TcpStream, rule_engine: Arc, cert_manager: Arc, + remote_addr: std::net::SocketAddr, ) -> Result<()> { - debug!("Handling new HTTPS connection"); + debug!("Handling new HTTPS connection from {}", remote_addr); // Peek at the first few bytes to determine if this is HTTP or TLS let mut peek_buf = [0; 6]; @@ -64,18 +65,18 @@ pub async fn handle_https_connection( if peek_buf[0] == 0x16 && n > 1 && (peek_buf[1] == 0x03 || peek_buf[1] == 0x02) { // This is a TLS ClientHello - we're in transparent proxy mode debug!("Detected TLS ClientHello - transparent proxy mode"); - handle_transparent_tls(stream, rule_engine, cert_manager).await + handle_transparent_tls(stream, rule_engine, cert_manager, remote_addr).await } else if peek_buf[0] >= 0x41 && peek_buf[0] <= 0x5A { // This looks like HTTP (starts with uppercase ASCII letter) // Check if it's a CONNECT request let request_str = String::from_utf8_lossy(&peek_buf); if request_str.starts_with("CONNEC") { debug!("Detected CONNECT request - explicit proxy mode"); - handle_connect_tunnel(stream, rule_engine, cert_manager).await + handle_connect_tunnel(stream, rule_engine, cert_manager, remote_addr).await } else { // Regular HTTP on HTTPS port debug!("Detected plain HTTP on HTTPS port"); - handle_plain_http(stream, rule_engine, cert_manager).await + handle_plain_http(stream, rule_engine, cert_manager, remote_addr).await } } else { warn!( @@ -159,6 +160,7 @@ async fn handle_transparent_tls( mut stream: TcpStream, rule_engine: Arc, cert_manager: Arc, + remote_addr: std::net::SocketAddr, ) -> Result<()> { debug!("Handling transparent TLS connection"); @@ -212,7 +214,7 @@ async fn handle_transparent_tls( let io = TokioIo::new(tls_stream); let service = service_fn(move |req| { let host_clone = hostname.clone(); - handle_decrypted_https_request(req, Arc::clone(&rule_engine), host_clone) + handle_decrypted_https_request(req, Arc::clone(&rule_engine), host_clone, remote_addr) }); debug!("Starting HTTP/1.1 server for decrypted requests"); @@ -230,6 +232,7 @@ async fn handle_connect_tunnel( stream: TcpStream, rule_engine: Arc, cert_manager: Arc, + remote_addr: std::net::SocketAddr, ) -> Result<()> { debug!("Handling CONNECT tunnel"); @@ -305,8 +308,9 @@ async fn handle_connect_tunnel( // Check if this host is allowed let full_url = format!("https://{}", target); + let requester_ip = remote_addr.ip().to_string(); let evaluation = rule_engine - .evaluate_with_context(Method::GET, &full_url) + .evaluate_with_context_and_ip(Method::GET, &full_url, &requester_ip) .await; match evaluation.action { Action::Allow => { @@ -337,7 +341,7 @@ async fn handle_connect_tunnel( debug!("Sent 200 Connection Established, starting TLS handshake"); // Now perform TLS handshake with the client - perform_tls_interception(stream, rule_engine, cert_manager, host).await + perform_tls_interception(stream, rule_engine, cert_manager, host, remote_addr).await } Action::Deny => { warn!("CONNECT denied to: {}", host); @@ -372,6 +376,7 @@ async fn perform_tls_interception( rule_engine: Arc, cert_manager: Arc, host: &str, + remote_addr: std::net::SocketAddr, ) -> Result<()> { // Get certificate for the host let (cert_chain, key) = cert_manager @@ -405,9 +410,10 @@ async fn perform_tls_interception( // Now handle the decrypted HTTPS requests let io = TokioIo::new(tls_stream); let host_string = host.to_string(); + let remote_addr_copy = remote_addr; // Copy for the closure let service = service_fn(move |req| { let host_clone = host_string.clone(); - handle_decrypted_https_request(req, Arc::clone(&rule_engine), host_clone) + handle_decrypted_https_request(req, Arc::clone(&rule_engine), host_clone, remote_addr_copy) }); debug!("Starting HTTP/1.1 server for decrypted requests"); @@ -425,12 +431,18 @@ async fn handle_plain_http( stream: TcpStream, rule_engine: Arc, cert_manager: Arc, + remote_addr: std::net::SocketAddr, ) -> Result<()> { debug!("Handling plain HTTP on HTTPS port"); let io = TokioIo::new(stream); let service = service_fn(move |req| { - crate::proxy::handle_http_request(req, Arc::clone(&rule_engine), Arc::clone(&cert_manager)) + crate::proxy::handle_http_request( + req, + Arc::clone(&rule_engine), + Arc::clone(&cert_manager), + remote_addr, + ) }); http1::Builder::new() @@ -447,6 +459,7 @@ async fn handle_decrypted_https_request( req: Request, rule_engine: Arc, host: String, + remote_addr: std::net::SocketAddr, ) -> Result>, std::convert::Infallible> { let method = req.method().clone(); let uri = req.uri().clone(); @@ -455,11 +468,15 @@ async fn handle_decrypted_https_request( let path = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/"); let full_url = format!("https://{}{}", host, path); - debug!("Proxying HTTPS request: {} {}", method, full_url); + debug!( + "Proxying HTTPS request: {} {} from {}", + method, full_url, remote_addr + ); - // Evaluate rules with method + // Evaluate rules with method and requester IP + let requester_ip = remote_addr.ip().to_string(); let evaluation = rule_engine - .evaluate_with_context(method.clone(), &full_url) + .evaluate_with_context_and_ip(method.clone(), &full_url, &requester_ip) .await; match evaluation.action { Action::Allow => { @@ -671,8 +688,8 @@ mod tests { // Spawn proxy handler tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let _ = handle_connect_tunnel(stream, rule_engine, cert_manager).await; + let (stream, addr) = listener.accept().await.unwrap(); + let _ = handle_connect_tunnel(stream, rule_engine, cert_manager, addr).await; }); // Connect to proxy @@ -706,8 +723,8 @@ mod tests { // Spawn proxy handler tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let _ = handle_connect_tunnel(stream, rule_engine, cert_manager).await; + let (stream, addr) = listener.accept().await.unwrap(); + let _ = handle_connect_tunnel(stream, rule_engine, cert_manager, addr).await; }); // Connect to proxy @@ -743,8 +760,8 @@ mod tests { // Spawn proxy handler tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let _ = handle_transparent_tls(stream, rule_engine, cert_manager).await; + let (stream, addr) = listener.accept().await.unwrap(); + let _ = handle_transparent_tls(stream, rule_engine, cert_manager, addr).await; }); // Connect to proxy with TLS directly (transparent mode) @@ -815,8 +832,8 @@ mod tests { let cert_manager = cert_manager.clone(); let rule_engine = rule_engine.clone(); tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let _ = handle_https_connection(stream, rule_engine, cert_manager).await; + let (stream, addr) = listener.accept().await.unwrap(); + let _ = handle_https_connection(stream, rule_engine, cert_manager, addr).await; }); let mut stream = TcpStream::connect(addr).await.unwrap(); @@ -848,9 +865,9 @@ mod tests { // Start proxy handler tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); + let (stream, addr) = listener.accept().await.unwrap(); // Use the actual transparent TLS handler (which will extract SNI, etc.) - let _ = handle_transparent_tls(stream, rule_engine, cert_manager).await; + let _ = handle_transparent_tls(stream, rule_engine, cert_manager, addr).await; }); // Give the server time to start diff --git a/src/rules.rs b/src/rules.rs index cf3ca3a..6a50618 100644 --- a/src/rules.rs +++ b/src/rules.rs @@ -44,7 +44,7 @@ impl EvaluationResult { #[async_trait] pub trait RuleEngineTrait: Send + Sync { - async fn evaluate(&self, method: Method, url: &str) -> EvaluationResult; + async fn evaluate(&self, method: Method, url: &str, requester_ip: &str) -> EvaluationResult; fn name(&self) -> &str; } @@ -65,8 +65,11 @@ impl LoggingRuleEngine { #[async_trait] impl RuleEngineTrait for LoggingRuleEngine { - async fn evaluate(&self, method: Method, url: &str) -> EvaluationResult { - let result = self.engine.evaluate(method.clone(), url).await; + async fn evaluate(&self, method: Method, url: &str, requester_ip: &str) -> EvaluationResult { + let result = self + .engine + .evaluate(method.clone(), url, requester_ip) + .await; if let Some(log) = &self.request_log && let Ok(mut file) = log.lock() @@ -110,11 +113,24 @@ impl RuleEngine { } pub async fn evaluate(&self, method: Method, url: &str) -> Action { - self.inner.evaluate(method, url).await.action + self.inner.evaluate(method, url, "127.0.0.1").await.action } pub async fn evaluate_with_context(&self, method: Method, url: &str) -> EvaluationResult { - self.inner.evaluate(method, url).await + self.inner.evaluate(method, url, "127.0.0.1").await + } + + pub async fn evaluate_with_ip(&self, method: Method, url: &str, requester_ip: &str) -> Action { + self.inner.evaluate(method, url, requester_ip).await.action + } + + pub async fn evaluate_with_context_and_ip( + &self, + method: Method, + url: &str, + requester_ip: &str, + ) -> EvaluationResult { + self.inner.evaluate(method, url, requester_ip).await } } diff --git a/src/rules/script.rs b/src/rules/script.rs index a588a89..3d10c5d 100644 --- a/src/rules/script.rs +++ b/src/rules/script.rs @@ -15,7 +15,12 @@ impl ScriptRuleEngine { ScriptRuleEngine { script } } - async fn execute_script(&self, method: Method, url: &str) -> (bool, String) { + async fn execute_script( + &self, + method: Method, + url: &str, + requester_ip: &str, + ) -> (bool, String) { let parsed_url = match Url::parse(url) { Ok(u) => u, Err(e) => { @@ -29,8 +34,8 @@ impl ScriptRuleEngine { let path = parsed_url.path(); debug!( - "Executing script for {} {} (host: {}, path: {})", - method, url, host, path + "Executing script for {} {} from {} (host: {}, path: {})", + method, url, requester_ip, host, path ); // Build the command @@ -47,6 +52,7 @@ impl ScriptRuleEngine { .env("HTTPJAIL_SCHEME", scheme) .env("HTTPJAIL_HOST", host) .env("HTTPJAIL_PATH", path) + .env("HTTPJAIL_REQUESTER_IP", requester_ip) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) .kill_on_drop(true); // Ensure child is killed if dropped @@ -98,8 +104,8 @@ impl ScriptRuleEngine { #[async_trait] impl RuleEngineTrait for ScriptRuleEngine { - async fn evaluate(&self, method: Method, url: &str) -> EvaluationResult { - let (allowed, context) = self.execute_script(method.clone(), url).await; + async fn evaluate(&self, method: Method, url: &str, requester_ip: &str) -> EvaluationResult { + let (allowed, context) = self.execute_script(method.clone(), url, requester_ip).await; if allowed { debug!("ALLOW: {} {} (script allowed)", method, url); @@ -152,7 +158,7 @@ exit 0 let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); let result = engine - .evaluate(Method::GET, "https://example.com/test") + .evaluate(Method::GET, "https://example.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); @@ -181,7 +187,7 @@ exit 1 let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); let result = engine - .evaluate(Method::GET, "https://example.com/test") + .evaluate(Method::GET, "https://example.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); @@ -211,7 +217,7 @@ exit 1 let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); let result = engine - .evaluate(Method::GET, "https://example.com/test") + .evaluate(Method::GET, "https://example.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); @@ -247,12 +253,12 @@ fi let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); let result = engine - .evaluate(Method::GET, "https://allowed.com/test") + .evaluate(Method::GET, "https://allowed.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); let result = engine - .evaluate(Method::GET, "https://blocked.com/test") + .evaluate(Method::GET, "https://blocked.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); assert_eq!( @@ -267,12 +273,12 @@ fi let engine = ScriptRuleEngine::new("test \"$HTTPJAIL_HOST\" = \"github.com\"".to_string()); let result = engine - .evaluate(Method::GET, "https://github.com/test") + .evaluate(Method::GET, "https://github.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); let result = engine - .evaluate(Method::GET, "https://example.com/test") + .evaluate(Method::GET, "https://example.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); } diff --git a/src/rules/v8_js.rs b/src/rules/v8_js.rs index 9a0fa23..7476e4f 100644 --- a/src/rules/v8_js.rs +++ b/src/rules/v8_js.rs @@ -63,7 +63,12 @@ impl V8JsRuleEngine { } /// Evaluate the JavaScript rule against the given request - fn execute_js_rule(&self, method: &Method, url: &str) -> (bool, Option) { + fn execute_js_rule( + &self, + method: &Method, + url: &str, + requester_ip: &str, + ) -> (bool, Option) { let parsed_url = match Url::parse(url) { Ok(u) => u, Err(e) => { @@ -83,7 +88,7 @@ impl V8JsRuleEngine { // Create a new isolate and context for this evaluation // This ensures thread safety at the cost of some performance - match self.create_and_execute(method.as_str(), url, scheme, host, path) { + match self.create_and_execute(method.as_str(), url, scheme, host, path, requester_ip) { Ok(result) => result, Err(e) => { warn!("JavaScript execution failed: {}", e); @@ -99,6 +104,7 @@ impl V8JsRuleEngine { scheme: &str, host: &str, path: &str, + requester_ip: &str, ) -> Result<(bool, Option), Box> { let mut isolate = v8::Isolate::new(v8::CreateParams::default()); let handle_scope = &mut v8::HandleScope::new(&mut isolate); @@ -138,6 +144,11 @@ impl V8JsRuleEngine { r_obj.set(context_scope, key.into(), path_str.into()); } + if let Some(ip_str) = v8::String::new(context_scope, requester_ip) { + let key = v8::String::new(context_scope, "requester_ip").unwrap(); + r_obj.set(context_scope, key.into(), ip_str.into()); + } + // Initialize block_message as undefined (can be set by user script) let block_msg_key = v8::String::new(context_scope, "block_message").unwrap(); let undefined_val = v8::undefined(context_scope); @@ -188,16 +199,17 @@ impl V8JsRuleEngine { #[async_trait] impl RuleEngineTrait for V8JsRuleEngine { - async fn evaluate(&self, method: Method, url: &str) -> EvaluationResult { + async fn evaluate(&self, method: Method, url: &str, requester_ip: &str) -> EvaluationResult { // Run the JavaScript evaluation in a blocking task to avoid // issues with V8's single-threaded nature let js_code = self.js_code.clone(); let method_clone = method.clone(); let url_clone = url.to_string(); + let ip_clone = requester_ip.to_string(); let (allowed, block_message) = tokio::task::spawn_blocking(move || { let engine = V8JsRuleEngine { js_code }; - engine.execute_js_rule(&method_clone, &url_clone) + engine.execute_js_rule(&method_clone, &url_clone, &ip_clone) }) .await .unwrap_or_else(|e| { @@ -237,7 +249,7 @@ mod tests { let engine = V8JsRuleEngine::new(js_code).expect("Failed to create JS engine"); let result = engine - .evaluate(Method::GET, "https://github.com/test") + .evaluate(Method::GET, "https://github.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); } @@ -249,7 +261,7 @@ mod tests { let engine = V8JsRuleEngine::new(js_code).expect("Failed to create JS engine"); let result = engine - .evaluate(Method::GET, "https://example.com/test") + .evaluate(Method::GET, "https://example.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); } @@ -261,12 +273,12 @@ mod tests { let engine = V8JsRuleEngine::new(js_code).expect("Failed to create JS engine"); let result = engine - .evaluate(Method::GET, "https://api.github.com/v3") + .evaluate(Method::GET, "https://api.github.com/v3", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); let result = engine - .evaluate(Method::POST, "https://api.github.com/v3") + .evaluate(Method::POST, "https://api.github.com/v3", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); } @@ -278,12 +290,12 @@ mod tests { let engine = V8JsRuleEngine::new(js_code).expect("Failed to create JS engine"); let result = engine - .evaluate(Method::GET, "https://example.com/api/test") + .evaluate(Method::GET, "https://example.com/api/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); let result = engine - .evaluate(Method::GET, "https://example.com/public/test") + .evaluate(Method::GET, "https://example.com/public/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); } @@ -297,25 +309,25 @@ mod tests { // Test GitHub allow let result = engine - .evaluate(Method::GET, "https://github.com/user/repo") + .evaluate(Method::GET, "https://github.com/user/repo", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); // Test social media block let result = engine - .evaluate(Method::GET, "https://facebook.com/profile") + .evaluate(Method::GET, "https://facebook.com/profile", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); // Test API allow let result = engine - .evaluate(Method::POST, "https://example.com/api/data") + .evaluate(Method::POST, "https://example.com/api/data", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); // Test default deny let result = engine - .evaluate(Method::GET, "https://example.com/public") + .evaluate(Method::GET, "https://example.com/public", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); } @@ -338,7 +350,7 @@ mod tests { // Should return deny on runtime error let result = engine - .evaluate(Method::GET, "https://example.com/test") + .evaluate(Method::GET, "https://example.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); } @@ -352,7 +364,7 @@ mod tests { // Should block facebook with custom message let result = engine - .evaluate(Method::GET, "https://facebook.com/test") + .evaluate(Method::GET, "https://facebook.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); assert_eq!( @@ -362,7 +374,7 @@ mod tests { // Should allow others without message let result = engine - .evaluate(Method::GET, "https://example.com/test") + .evaluate(Method::GET, "https://example.com/test", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); assert_eq!(result.context, None); diff --git a/tests/script_integration.rs b/tests/script_integration.rs index ac21e73..09e5e91 100644 --- a/tests/script_integration.rs +++ b/tests/script_integration.rs @@ -2,6 +2,7 @@ use httpjail::rules::script::ScriptRuleEngine; use httpjail::rules::{Action, RuleEngineTrait}; use hyper::Method; use std::fs; +use std::io::Write; use tempfile::NamedTempFile; #[tokio::test] @@ -34,13 +35,13 @@ fi // Test allowed request let result = engine - .evaluate(Method::GET, "https://github.com/user/repo") + .evaluate(Method::GET, "https://github.com/user/repo", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); // Test denied request with context let result = engine - .evaluate(Method::POST, "https://example.com/api") + .evaluate(Method::POST, "https://example.com/api", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); assert_eq!( @@ -82,18 +83,18 @@ fi // Test allowed methods let result = engine - .evaluate(Method::GET, "https://example.com/api") + .evaluate(Method::GET, "https://example.com/api", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); let result = engine - .evaluate(Method::HEAD, "https://example.com/api") + .evaluate(Method::HEAD, "https://example.com/api", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); // Test denied method with context let result = engine - .evaluate(Method::POST, "https://example.com/api") + .evaluate(Method::POST, "https://example.com/api", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); assert_eq!(result.context, Some("Method POST not allowed".to_string())); @@ -110,12 +111,16 @@ async fn test_inline_script_evaluation() { ); let result = engine - .evaluate(Method::GET, "https://example.com/api/v1/health") + .evaluate( + Method::GET, + "https://example.com/api/v1/health", + "127.0.0.1", + ) .await; assert!(matches!(result.action, Action::Allow)); let result = engine - .evaluate(Method::GET, "https://example.com/api/v2/users") + .evaluate(Method::GET, "https://example.com/api/v2/users", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); } @@ -156,7 +161,7 @@ fi // Test allowed GitHub GET let result = engine - .evaluate(Method::GET, "https://github.com/user/repo") + .evaluate(Method::GET, "https://github.com/user/repo", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); assert_eq!( @@ -166,14 +171,14 @@ fi // Test allowed API POST let result = engine - .evaluate(Method::POST, "https://api.example.com/users") + .evaluate(Method::POST, "https://api.example.com/users", "127.0.0.1") .await; assert!(matches!(result.action, Action::Allow)); assert_eq!(result.context, Some("API write access allowed".to_string())); // Test denied request let result = engine - .evaluate(Method::POST, "https://github.com/user/repo") + .evaluate(Method::POST, "https://github.com/user/repo", "127.0.0.1") .await; assert!(matches!(result.action, Action::Deny)); assert!( @@ -186,3 +191,45 @@ fi // TempPath will be automatically deleted when it goes out of scope drop(script_path); } + +#[tokio::test] +async fn test_script_receives_requester_ip() { + // Create a script that logs the requester IP + let mut script = NamedTempFile::new().unwrap(); + let script_content = r#"#!/bin/bash +if [ -n "$HTTPJAIL_REQUESTER_IP" ]; then + echo "Request from IP: $HTTPJAIL_REQUESTER_IP" + exit 0 +else + echo "No requester IP provided" + exit 1 +fi +"#; + script.write_all(script_content.as_bytes()).unwrap(); + script.flush().unwrap(); + + let script_path = script.into_temp_path(); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&script_path, std::fs::Permissions::from_mode(0o755)).unwrap(); + } + + let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); + + // Test with a specific IP + let result = engine + .evaluate(Method::GET, "https://example.com", "192.168.1.100") + .await; + + assert!(matches!(result.action, Action::Allow)); + assert!(result.context.is_some()); + assert!( + result + .context + .unwrap() + .contains("Request from IP: 192.168.1.100") + ); + + drop(script_path); +}