Skip to content

Commit 7ba9495

Browse files
authored
Merge e44935f into e8e55b8
2 parents e8e55b8 + e44935f commit 7ba9495

6 files changed

Lines changed: 148 additions & 28 deletions

File tree

internal/extension/extension.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,10 @@ func (em *ExtensionManager) SendStartInvocationRequest(ctx context.Context, even
109109
if traceId != "" {
110110
ctx = context.WithValue(ctx, DdTraceId, traceId)
111111
}
112-
parentId := response.Header.Get(string(DdParentId))
112+
parentId := traceId
113+
if pid := response.Header.Get(string(DdParentId)); pid != "" {
114+
parentId = pid
115+
}
113116
if parentId != "" {
114117
ctx = context.WithValue(ctx, DdParentId, parentId)
115118
}

internal/extension/extension_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,29 @@ func TestExtensionStartInvokeWithTraceContext(t *testing.T) {
174174
assert.Equal(t, mockSamplingPriority, samplingPriority)
175175
}
176176

177+
func TestExtensionStartInvokeWithTraceContextNoParentID(t *testing.T) {
178+
headers := http.Header{}
179+
headers.Set(string(DdTraceId), mockTraceId)
180+
headers.Set(string(DdSamplingPriority), mockSamplingPriority)
181+
182+
em := &ExtensionManager{
183+
startInvocationUrl: startInvocationUrl,
184+
httpClient: &ClientSuccessStartInvoke{
185+
headers: headers,
186+
},
187+
}
188+
ctx := em.SendStartInvocationRequest(context.TODO(), []byte{})
189+
traceId := ctx.Value(DdTraceId)
190+
parentId := ctx.Value(DdParentId)
191+
samplingPriority := ctx.Value(DdSamplingPriority)
192+
err := em.Flush()
193+
194+
assert.Nil(t, err)
195+
assert.Equal(t, mockTraceId, traceId)
196+
assert.Equal(t, mockTraceId, parentId)
197+
assert.Equal(t, mockSamplingPriority, samplingPriority)
198+
}
199+
177200
func TestExtensionEndInvocation(t *testing.T) {
178201
em := &ExtensionManager{
179202
endInvocationUrl: endInvocationUrl,

internal/trace/context.go

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"strconv"
1818
"strings"
1919

20+
"github.com/DataDog/datadog-lambda-go/internal/extension"
2021
"github.com/DataDog/datadog-lambda-go/internal/logger"
2122
"github.com/aws/aws-xray-sdk-go/header"
2223
"github.com/aws/aws-xray-sdk-go/xray"
@@ -47,7 +48,7 @@ var DefaultTraceExtractor = getHeadersFromEventHeaders
4748
// contextWithRootTraceContext uses the incoming event and context object payloads to determine
4849
// the root TraceContext and then adds that TraceContext to the context object.
4950
func contextWithRootTraceContext(ctx context.Context, ev json.RawMessage, mergeXrayTraces bool, extractor ContextExtractor) (context.Context, error) {
50-
datadogTraceContext, gotDatadogTraceContext := getTraceContext(extractor(ctx, ev))
51+
datadogTraceContext, gotDatadogTraceContext := getTraceContext(ctx, extractor(ctx, ev))
5152

5253
xrayTraceContext, errGettingXrayContext := convertXrayTraceContextFromLambdaContext(ctx)
5354
if errGettingXrayContext != nil {
@@ -126,21 +127,36 @@ func createDummySubsegmentForXrayConverter(ctx context.Context, traceCtx TraceCo
126127
return nil
127128
}
128129

129-
func getTraceContext(context map[string]string) (TraceContext, bool) {
130+
func getTraceContext(ctx context.Context, headers map[string]string) (TraceContext, bool) {
130131
tc := TraceContext{}
131132

132-
traceID, ok := context[traceIDHeader]
133-
if !ok {
133+
traceID := headers[traceIDHeader]
134+
if traceID == "" {
135+
if val, ok := ctx.Value(extension.DdTraceId).(string); ok {
136+
traceID = val
137+
}
138+
}
139+
if traceID == "" {
134140
return tc, false
135141
}
136142

137-
parentID, ok := context[parentIDHeader]
138-
if !ok {
143+
parentID := headers[parentIDHeader]
144+
if parentID == "" {
145+
if val, ok := ctx.Value(extension.DdParentId).(string); ok {
146+
parentID = val
147+
}
148+
}
149+
if parentID == "" {
139150
return tc, false
140151
}
141152

142-
samplingPriority, ok := context[samplingPriorityHeader]
143-
if !ok {
153+
samplingPriority := headers[samplingPriorityHeader]
154+
if samplingPriority == "" {
155+
if val, ok := ctx.Value(extension.DdSamplingPriority).(string); ok {
156+
samplingPriority = val
157+
}
158+
}
159+
if samplingPriority == "" {
144160
samplingPriority = "1" //sampler-keep
145161
}
146162

internal/trace/context_test.go

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@ import (
1414
"io/ioutil"
1515
"testing"
1616

17+
"github.com/DataDog/datadog-lambda-go/internal/extension"
1718
"github.com/aws/aws-xray-sdk-go/header"
18-
1919
"github.com/aws/aws-xray-sdk-go/xray"
20-
2120
"github.com/stretchr/testify/assert"
2221
)
2322

@@ -45,6 +44,20 @@ func mockLambdaXRayTraceContext(ctx context.Context, traceID, parentID string, s
4544
return context.WithValue(ctx, xray.LambdaTraceHeaderKey, headerString)
4645
}
4746

47+
func mockTraceContext(traceID, parentID, samplingPriority string) context.Context {
48+
ctx := context.Background()
49+
if traceID != "" {
50+
ctx = context.WithValue(ctx, extension.DdTraceId, traceID)
51+
}
52+
if parentID != "" {
53+
ctx = context.WithValue(ctx, extension.DdParentId, parentID)
54+
}
55+
if samplingPriority != "" {
56+
ctx = context.WithValue(ctx, extension.DdSamplingPriority, samplingPriority)
57+
}
58+
return ctx
59+
}
60+
4861
func loadRawJSON(t *testing.T, filename string) *json.RawMessage {
4962
bytes, err := ioutil.ReadFile(filename)
5063
if err != nil {
@@ -60,7 +73,7 @@ func TestGetDatadogTraceContextForTraceMetadataNonProxyEvent(t *testing.T) {
6073
ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true)
6174
ev := loadRawJSON(t, "../testdata/apig-event-with-headers.json")
6275

63-
headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev))
76+
headers, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
6477
assert.True(t, ok)
6578

6679
expected := TraceContext{
@@ -75,7 +88,7 @@ func TestGetDatadogTraceContextForTraceMetadataWithMixedCaseHeaders(t *testing.T
7588
ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true)
7689
ev := loadRawJSON(t, "../testdata/non-proxy-with-mixed-case-headers.json")
7790

78-
headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev))
91+
headers, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
7992
assert.True(t, ok)
8093

8194
expected := TraceContext{
@@ -90,7 +103,7 @@ func TestGetDatadogTraceContextForTraceMetadataWithMissingSamplingPriority(t *te
90103
ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true)
91104
ev := loadRawJSON(t, "../testdata/non-proxy-with-missing-sampling-priority.json")
92105

93-
headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev))
106+
headers, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
94107
assert.True(t, ok)
95108

96109
expected := TraceContext{
@@ -105,18 +118,75 @@ func TestGetDatadogTraceContextForInvalidData(t *testing.T) {
105118
ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true)
106119
ev := loadRawJSON(t, "../testdata/invalid.json")
107120

108-
_, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev))
121+
_, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
109122
assert.False(t, ok)
110123
}
111124

112125
func TestGetDatadogTraceContextForMissingData(t *testing.T) {
113126
ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true)
114127
ev := loadRawJSON(t, "../testdata/non-proxy-no-headers.json")
115128

116-
_, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev))
129+
_, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
117130
assert.False(t, ok)
118131
}
119132

133+
func TestGetDatadogTraceContextFromContextObject(t *testing.T) {
134+
testcases := []struct {
135+
traceID string
136+
parentID string
137+
samplingPriority string
138+
expectTC TraceContext
139+
expectOk bool
140+
}{
141+
{
142+
"trace",
143+
"parent",
144+
"sampling",
145+
TraceContext{
146+
"x-datadog-trace-id": "trace",
147+
"x-datadog-parent-id": "parent",
148+
"x-datadog-sampling-priority": "sampling",
149+
},
150+
true,
151+
},
152+
{
153+
"",
154+
"parent",
155+
"sampling",
156+
TraceContext{},
157+
false,
158+
},
159+
{
160+
"trace",
161+
"",
162+
"sampling",
163+
TraceContext{},
164+
false,
165+
},
166+
{
167+
"trace",
168+
"parent",
169+
"",
170+
TraceContext{
171+
"x-datadog-trace-id": "trace",
172+
"x-datadog-parent-id": "parent",
173+
"x-datadog-sampling-priority": "1",
174+
},
175+
true,
176+
},
177+
}
178+
179+
ev := loadRawJSON(t, "../testdata/non-proxy-no-headers.json")
180+
for _, test := range testcases {
181+
t.Run(test.traceID+test.parentID+test.samplingPriority, func(t *testing.T) {
182+
ctx := mockTraceContext(test.traceID, test.parentID, test.samplingPriority)
183+
tc, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
184+
assert.Equal(t, test.expectTC, tc)
185+
assert.Equal(t, test.expectOk, ok)
186+
})
187+
}
188+
}
189+
120190
func TestConvertXRayTraceID(t *testing.T) {
121191
output, err := convertXRayTraceIDToDatadogTraceID(mockXRayTraceID)
122192
assert.NoError(t, err)

internal/trace/listener.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ func (l *Listener) HandlerStarted(ctx context.Context, msg json.RawMessage) cont
6464
return ctx
6565
}
6666

67+
if l.universalInstrumentation && l.extensionManager.IsExtensionRunning() {
68+
ctx = l.extensionManager.SendStartInvocationRequest(ctx, msg)
69+
}
70+
6771
ctx, _ = contextWithRootTraceContext(ctx, msg, l.mergeXrayTraces, l.traceContextExtractor)
6872

6973
if !tracerInitialized {
@@ -77,15 +81,11 @@ func (l *Listener) HandlerStarted(ctx context.Context, msg json.RawMessage) cont
7781
}
7882

7983
isDdServerlessSpan := l.universalInstrumentation && l.extensionManager.IsExtensionRunning()
80-
functionExecutionSpan = startFunctionExecutionSpan(ctx, l.mergeXrayTraces, isDdServerlessSpan)
84+
functionExecutionSpan, ctx = startFunctionExecutionSpan(ctx, l.mergeXrayTraces, isDdServerlessSpan)
8185

8286
// Add the span to the context so the user can create child spans
8387
ctx = tracer.ContextWithSpan(ctx, functionExecutionSpan)
8488

85-
if l.universalInstrumentation && l.extensionManager.IsExtensionRunning() {
86-
ctx = l.extensionManager.SendStartInvocationRequest(ctx, msg)
87-
}
88-
8989
return ctx
9090
}
9191

@@ -104,7 +104,7 @@ func (l *Listener) HandlerFinished(ctx context.Context, err error) {
104104

105105
// startFunctionExecutionSpan starts a span that represents the current Lambda function execution
106106
// and returns the span so that it can be finished when the function execution is complete
107-
func startFunctionExecutionSpan(ctx context.Context, mergeXrayTraces bool, isDdServerlessSpan bool) tracer.Span {
107+
func startFunctionExecutionSpan(ctx context.Context, mergeXrayTraces bool, isDdServerlessSpan bool) (tracer.Span, context.Context) {
108108
// Extract information from context
109109
lambdaCtx, _ := lambdacontext.FromContext(ctx)
110110
rootTraceContext, ok := ctx.Value(traceContextKey).(TraceContext)
@@ -149,7 +149,9 @@ func startFunctionExecutionSpan(ctx context.Context, mergeXrayTraces bool, isDdS
149149
span.SetTag("_dd.parent_source", "xray")
150150
}
151151

152-
return span
152+
ctx = context.WithValue(ctx, extension.DdSpanId, fmt.Sprint(span.Context().SpanID()))
153+
154+
return span, ctx
153155
}
154156

155157
func separateVersionFromFunctionArn(functionArn string) (arnWithoutVersion string, functionVersion string) {

internal/trace/listener_test.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package trace
1010

1111
import (
1212
"context"
13+
"fmt"
1314
"testing"
1415

1516
"github.com/DataDog/datadog-lambda-go/internal/extension"
@@ -75,7 +76,7 @@ func TestStartFunctionExecutionSpanFromXrayWithMergeEnabled(t *testing.T) {
7576
mt := mocktracer.Start()
7677
defer mt.Stop()
7778

78-
span := startFunctionExecutionSpan(ctx, true, false)
79+
span, ctx := startFunctionExecutionSpan(ctx, true, false)
7980
span.Finish()
8081
finishedSpan := mt.FinishedSpans()[0]
8182

@@ -91,6 +92,7 @@ func TestStartFunctionExecutionSpanFromXrayWithMergeEnabled(t *testing.T) {
9192
assert.Equal(t, "mockfunctionname", finishedSpan.Tag("functionname"))
9293
assert.Equal(t, "serverless", finishedSpan.Tag("span.type"))
9394
assert.Equal(t, "xray", finishedSpan.Tag("_dd.parent_source"))
95+
assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string))
9496
}
9597

9698
func TestStartFunctionExecutionSpanFromXrayWithMergeDisabled(t *testing.T) {
@@ -105,11 +107,12 @@ func TestStartFunctionExecutionSpanFromXrayWithMergeDisabled(t *testing.T) {
105107
mt := mocktracer.Start()
106108
defer mt.Stop()
107109

108-
span := startFunctionExecutionSpan(ctx, false, false)
110+
span, ctx := startFunctionExecutionSpan(ctx, false, false)
109111
span.Finish()
110112
finishedSpan := mt.FinishedSpans()[0]
111113

112114
assert.Equal(t, nil, finishedSpan.Tag("_dd.parent_source"))
115+
assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string))
113116
}
114117

115118
func TestStartFunctionExecutionSpanFromEventWithMergeEnabled(t *testing.T) {
@@ -124,11 +127,12 @@ func TestStartFunctionExecutionSpanFromEventWithMergeEnabled(t *testing.T) {
124127
mt := mocktracer.Start()
125128
defer mt.Stop()
126129

127-
span := startFunctionExecutionSpan(ctx, true, false)
130+
span, ctx := startFunctionExecutionSpan(ctx, true, false)
128131
span.Finish()
129132
finishedSpan := mt.FinishedSpans()[0]
130133

131134
assert.Equal(t, "xray", finishedSpan.Tag("_dd.parent_source"))
135+
assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string))
132136
}
133137

134138
func TestStartFunctionExecutionSpanFromEventWithMergeDisabled(t *testing.T) {
@@ -143,11 +147,12 @@ func TestStartFunctionExecutionSpanFromEventWithMergeDisabled(t *testing.T) {
143147
mt := mocktracer.Start()
144148
defer mt.Stop()
145149

146-
span := startFunctionExecutionSpan(ctx, false, false)
150+
span, ctx := startFunctionExecutionSpan(ctx, false, false)
147151
span.Finish()
148152
finishedSpan := mt.FinishedSpans()[0]
149153

150154
assert.Equal(t, nil, finishedSpan.Tag("_dd.parent_source"))
155+
assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string))
151156
}
152157

153158
func TestStartFunctionExecutionSpanWithExtension(t *testing.T) {
@@ -162,9 +167,10 @@ func TestStartFunctionExecutionSpanWithExtension(t *testing.T) {
162167
mt := mocktracer.Start()
163168
defer mt.Stop()
164169

165-
span := startFunctionExecutionSpan(ctx, false, true)
170+
span, ctx := startFunctionExecutionSpan(ctx, false, true)
166171
span.Finish()
167172
finishedSpan := mt.FinishedSpans()[0]
168173

169174
assert.Equal(t, string(extension.DdSeverlessSpan), finishedSpan.Tag("resource.name"))
175+
assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string))
170176
}

0 commit comments

Comments
 (0)