Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import datadog.trace.api.DDTags;
import datadog.trace.api.InstrumenterConfig;
import datadog.trace.api.ProductActivation;
import datadog.trace.api.appsec.HttpClientRequest;
import datadog.trace.api.gateway.BlockResponseFunction;
import datadog.trace.api.gateway.Flow;
import datadog.trace.api.gateway.RequestContext;
Expand Down Expand Up @@ -99,7 +100,7 @@ public AgentSpan onRequest(final AgentSpan span, final REQUEST request) {
HTTP_RESOURCE_DECORATOR.withClientPath(span, method, url.getPath());
}
// SSRF exploit prevention check
onNetworkConnection(url.toString());
onHttpClientRequest(span, url.toString());
} else if (shouldSetResourceName()) {
span.setResourceName(DEFAULT_RESOURCE_NAME);
}
Expand Down Expand Up @@ -178,24 +179,20 @@ public long getResponseContentLength(final RESPONSE response) {
return 0;
}

private void onNetworkConnection(final String networkConnection) {
protected void onHttpClientRequest(final AgentSpan span, final String url) {
if (!APPSEC_RASP_ENABLED) {
return;
}
if (networkConnection == null) {
if (url == null) {
return;
}
final BiFunction<RequestContext, String, Flow<Void>> networkConnectionCallback =
final long requestId = span.getSpanId();
final BiFunction<RequestContext, HttpClientRequest, Flow<Void>> requestCb =
AgentTracer.get()
.getCallbackProvider(RequestContextSlot.APPSEC)
.getCallback(EVENTS.networkConnection());
.getCallback(EVENTS.httpClientRequest());

if (networkConnectionCallback == null) {
return;
}

final AgentSpan span = AgentTracer.get().activeSpan();
if (span == null) {
if (requestCb == null) {
return;
}

Expand All @@ -204,7 +201,7 @@ private void onNetworkConnection(final String networkConnection) {
return;
}

Flow<Void> flow = networkConnectionCallback.apply(ctx, networkConnection);
Flow<Void> flow = requestCb.apply(ctx, new HttpClientRequest(requestId, url));
Flow.Action action = flow.getAction();
if (action instanceof Flow.Action.RequestBlockingAction) {
BlockResponseFunction brf = ctx.getBlockResponseFunction();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package datadog.trace.bootstrap.instrumentation.decorator

import datadog.trace.api.DDTags
import datadog.trace.api.appsec.HttpClientRequest
import datadog.trace.api.config.AppSecConfig
import datadog.trace.api.gateway.CallbackProvider
import static datadog.trace.api.gateway.Events.EVENTS
Expand Down Expand Up @@ -249,8 +250,8 @@ class HttpClientDecoratorTest extends ClientDecoratorTest {
decorator.onRequest(span2, req)

then:
1 * callbackProvider.getCallback(EVENTS.networkConnection()) >> listener
1 * listener.apply(reqCtx, _ as String)
1 * callbackProvider.getCallback(EVENTS.httpClientRequest()) >> listener
1 * listener.apply(reqCtx, _ as HttpClientRequest)
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.datadog.appsec;

import com.datadog.appsec.api.security.ApiSecurityDownstreamSampler;
import com.datadog.appsec.api.security.ApiSecuritySampler;
import com.datadog.appsec.api.security.ApiSecuritySamplerImpl;
import com.datadog.appsec.api.security.AppSecSpanPostProcessor;
Expand Down Expand Up @@ -81,11 +82,14 @@ private static void doStart(SubscriptionService gw, SharedCommunicationObjects s
}
sco.createRemaining(config);

final double maxDownstreamRequestsRate =
config.getApiSecurityDownstreamRequestAnalysisSampleRate();
GatewayBridge gatewayBridge =
new GatewayBridge(
gw,
REPLACEABLE_EVENT_PRODUCER,
() -> API_SECURITY_SAMPLER,
ApiSecurityDownstreamSampler.build(maxDownstreamRequestsRate),
APP_SEC_CONFIG_SERVICE.getTraceSegmentPostProcessors());

loadModules(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.datadog.appsec.api.security;

import com.datadog.appsec.gateway.AppSecRequestContext;

public interface ApiSecurityDownstreamSampler {

boolean sampleHttpClientRequest(AppSecRequestContext ctx, long requestId);

boolean isSampled(AppSecRequestContext ctx, long requestId);

ApiSecurityDownstreamSampler INCLUDE_ALL =
new ApiSecurityDownstreamSampler() {
@Override
public boolean sampleHttpClientRequest(AppSecRequestContext ctx, long requestId) {
return true;
}

@Override
public boolean isSampled(AppSecRequestContext ctx, long requestId) {
return true;
}
};

ApiSecurityDownstreamSampler INCLUDE_NONE =
new ApiSecurityDownstreamSampler() {
@Override
public boolean sampleHttpClientRequest(AppSecRequestContext ctx, long requestId) {
return false;
}

@Override
public boolean isSampled(AppSecRequestContext ctx, long requestId) {
return false;
}
};

static ApiSecurityDownstreamSampler build(double rate) {
return rate <= 0D
? INCLUDE_NONE
: (rate >= 1D ? INCLUDE_ALL : new ApiSecurityDownstreamSamplerImpl(rate));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.datadog.appsec.api.security;

import com.datadog.appsec.gateway.AppSecRequestContext;
import java.util.concurrent.atomic.AtomicLong;

public class ApiSecurityDownstreamSamplerImpl implements ApiSecurityDownstreamSampler {

private static final long KNUTH_FACTOR = 1111111111111111111L;
private static final double SAMPLING_MAX = Math.pow(2, 64) - 1;

private final AtomicLong globalRequestCount = new AtomicLong(0);
private final double threshold;

public ApiSecurityDownstreamSamplerImpl(double rate) {
threshold = samplingCutoff(rate);
}

private static double samplingCutoff(double rate) {
if (rate < 0.5) {
return (long) (rate * SAMPLING_MAX) + Long.MIN_VALUE;
}
if (rate < 1.0) {
return (long) ((rate * SAMPLING_MAX) + Long.MIN_VALUE);
}
return Long.MAX_VALUE;
}

/**
* First sample the request to ensure we randomize the request and then check if the current
* server request has budget to analyze the downstream request.
*/
@Override
public boolean sampleHttpClientRequest(final AppSecRequestContext ctx, final long requestId) {
final long counter = updateRequestCount();
if (counter * KNUTH_FACTOR + Long.MIN_VALUE > threshold) {
return false;
}
return ctx.sampleHttpClientRequest(requestId);
}

@Override
public boolean isSampled(final AppSecRequestContext ctx, final long requestId) {
return ctx.isHttpClientRequestSampled(requestId);
}

private long updateRequestCount() {
return globalRequestCount.updateAndGet(cur -> (cur == Long.MAX_VALUE) ? 0L : cur + 1L);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,26 @@ public interface KnownAddresses {
/** The URL of a network resource being requested (outgoing request) */
Address<String> IO_NET_URL = new Address<>("server.io.net.url");

/** The headers of a network resource being requested (outgoing request) */
Address<Map<String, List<String>>> IO_NET_REQUEST_HEADERS =
new Address<>("server.io.net.request.headers");

/** The method of a network resource being requested (outgoing request) */
Address<String> IO_NET_REQUEST_METHOD = new Address<>("server.io.net.request.method");

/** The body of a network resource being requested (outgoing request) */
Address<Object> IO_NET_REQUEST_BODY = new Address<>("server.io.net.request.body");

/** The status of a network resource being requested (outgoing request) */
Address<Integer> IO_NET_RESPONSE_STATUS = new Address<>("server.io.net.response.status");

/** The response headers of a network resource being requested (outgoing request) */
Address<Map<String, List<String>>> IO_NET_RESPONSE_HEADERS =
new Address<>("server.io.net.response.headers");

/** The response body of a network resource being requested (outgoing request) */
Address<Object> IO_NET_RESPONSE_BODY = new Address<>("server.io.net.response.body");

/** The representation of opened file on the filesystem */
Address<String> IO_FS_FILE = new Address<>("server.io.fs.file");

Expand Down Expand Up @@ -206,6 +226,18 @@ static Address<?> forName(String name) {
return SESSION_ID;
case "server.io.net.url":
return IO_NET_URL;
case "server.io.net.request.headers":
return IO_NET_REQUEST_HEADERS;
case "server.io.net.request.method":
return IO_NET_REQUEST_METHOD;
case "server.io.net.request.body":
return IO_NET_REQUEST_BODY;
case "server.io.net.response.status":
return IO_NET_RESPONSE_STATUS;
case "server.io.net.response.headers":
return IO_NET_RESPONSE_HEADERS;
case "server.io.net.response.body":
return IO_NET_RESPONSE_BODY;
case "server.io.fs.file":
return IO_FS_FILE;
case "server.db.system":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ public class AppSecRequestContext implements DataBundle, Closeable {
private volatile Long apiSecurityEndpointHash;
private volatile byte keepType = PrioritySampling.SAMPLER_KEEP;

private static final AtomicInteger httpClientRequestCount = new AtomicInteger(0);
private static final Set<Long> sampledHttpClientRequests = new HashSet<>();

private static final AtomicIntegerFieldUpdater<AppSecRequestContext> WAF_TIMEOUTS_UPDATER =
AtomicIntegerFieldUpdater.newUpdater(AppSecRequestContext.class, "wafTimeouts");
private static final AtomicIntegerFieldUpdater<AppSecRequestContext> RASP_TIMEOUTS_UPDATER =
Expand Down Expand Up @@ -235,6 +238,29 @@ public void increaseRaspTimeouts() {
RASP_TIMEOUTS_UPDATER.incrementAndGet(this);
}

public void increaseHttpClientRequestCount() {
httpClientRequestCount.incrementAndGet();
}

public boolean sampleHttpClientRequest(final long id) {
synchronized (sampledHttpClientRequests) {
if (sampledHttpClientRequests.size()
< Config.get().getApiSecurityMaxDownstreamRequestBodyAnalysis()) {
sampledHttpClientRequests.add(id);
return true;
}
}
return false;
}

public boolean isHttpClientRequestSampled(final long id) {
return sampledHttpClientRequests.contains(id);
}

public int getHttpClientRequestCount() {
return httpClientRequestCount.get();
}

public int getWafTimeouts() {
return wafTimeouts;
}
Expand Down
Loading
Loading