Skip to content

Commit 943c98b

Browse files
authored
fix: explain requests shouldnt be encoded [IDE-1528] (#133)
1 parent e1a1b1f commit 943c98b

File tree

5 files changed

+190
-17
lines changed

5 files changed

+190
-17
lines changed

http/http.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ func NewDefaultClientFactory() HTTPClientFactory {
163163
return clientFunc
164164
}
165165

166-
func AddDefaultHeaders(req *http.Request, requestId string, orgId string, method string) {
166+
func AddDefaultHeaders(req *http.Request, requestId string, orgId string, method string, needsEncoding bool) {
167167
// if requestId is empty it will be enriched from the Gateway
168168
if len(requestId) > 0 {
169169
req.Header.Set("snyk-request-id", requestId)
@@ -175,7 +175,7 @@ func AddDefaultHeaders(req *http.Request, requestId string, orgId string, method
175175
// https://www.keycdn.com/blog/http-cache-headers
176176
req.Header.Set("Cache-Control", "private, max-age=0, no-cache")
177177

178-
if mustBeEncoded(method) {
178+
if mustBeEncoded(method, needsEncoding) {
179179
req.Header.Set("Content-Type", "application/octet-stream")
180180
req.Header.Set("Content-Encoding", "gzip")
181181
} else {
@@ -185,9 +185,9 @@ func AddDefaultHeaders(req *http.Request, requestId string, orgId string, method
185185

186186
// EncodeIfNeeded returns a byte buffer for the requestBody. Depending on the request method, it may encode the buffer.
187187
// (See http.mustBeEncoded for the list of methods which require encoding the request body.)
188-
func EncodeIfNeeded(method string, requestBody []byte) (*bytes.Buffer, error) {
188+
func EncodeIfNeeded(method string, requestBody []byte, needsEncoding bool) (*bytes.Buffer, error) {
189189
b := new(bytes.Buffer)
190-
if mustBeEncoded(method) {
190+
if mustBeEncoded(method, needsEncoding) {
191191
enc := encoding.NewEncoder(b)
192192
_, err := enc.Write(requestBody)
193193
if err != nil {
@@ -200,6 +200,6 @@ func EncodeIfNeeded(method string, requestBody []byte) (*bytes.Buffer, error) {
200200
}
201201

202202
// mustBeEncoded returns true if the request method requires the request body to be encoded.
203-
func mustBeEncoded(method string) bool {
204-
return method == http.MethodPost || method == http.MethodPut
203+
func mustBeEncoded(method string, needsEncoding bool) bool {
204+
return needsEncoding && (method == http.MethodPost || method == http.MethodPut)
205205
}

internal/analysis/analysis_legacy.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func (a *analysisOrchestrator) RunLegacyTest(ctx context.Context, bundleHash str
165165
httpMethod := http.MethodPost
166166

167167
// Encode the request body
168-
bodyBuffer, err := codeClientHTTP.EncodeIfNeeded(http.MethodPost, requestBody)
168+
bodyBuffer, err := codeClientHTTP.EncodeIfNeeded(http.MethodPost, requestBody, true)
169169
if err != nil {
170170
a.logger.Err(err).Str("requestBody", string(requestBody)).Msg("error encoding request body")
171171
return nil, scan.LegacyScanStatus{}, err
@@ -176,7 +176,7 @@ func (a *analysisOrchestrator) RunLegacyTest(ctx context.Context, bundleHash str
176176
a.logger.Err(err).Str("method", method).Msg("error creating HTTP request")
177177
return nil, scan.LegacyScanStatus{}, err
178178
}
179-
codeClientHTTP.AddDefaultHeaders(req, span.GetTraceId(), a.config.Organization(), httpMethod)
179+
codeClientHTTP.AddDefaultHeaders(req, span.GetTraceId(), a.config.Organization(), httpMethod, true)
180180

181181
// Make HTTP call
182182
resp, err := a.httpClient.Do(req)

internal/deepcode/client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ func (s *deepcodeClient) Request(
215215
return nil, err
216216
}
217217

218-
bodyBuffer, err := codeClientHTTP.EncodeIfNeeded(method, requestBody)
218+
bodyBuffer, err := codeClientHTTP.EncodeIfNeeded(method, requestBody, true)
219219
if err != nil {
220220
return nil, err
221221
}
@@ -225,7 +225,7 @@ func (s *deepcodeClient) Request(
225225
return nil, err
226226
}
227227

228-
codeClientHTTP.AddDefaultHeaders(req, codeClientHTTP.NoRequestId, s.config.Organization(), method)
228+
codeClientHTTP.AddDefaultHeaders(req, codeClientHTTP.NoRequestId, s.config.Organization(), method, true)
229229

230230
response, err := s.httpClient.Do(req)
231231
if err != nil {

llm/api_client.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain
4444
}
4545
}
4646

47-
responseBody, err := d.submitRequest(span.Context(), u, requestBody, "")
47+
responseBody, err := d.submitRequest(span.Context(), u, requestBody, "", false)
4848
if err != nil {
4949
return Explanations{}, err
5050
}
@@ -63,14 +63,14 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain
6363
return explains, nil
6464
}
6565

66-
func (d *DeepCodeLLMBindingImpl) submitRequest(ctx context.Context, url *url.URL, requestBody []byte, orgId string) ([]byte, error) {
66+
func (d *DeepCodeLLMBindingImpl) submitRequest(ctx context.Context, url *url.URL, requestBody []byte, orgId string, needsEncoding bool) ([]byte, error) {
6767
logger := d.logger.With().Str("method", "submitRequest").Logger()
6868
logger.Trace().Str("payload body: %s\n", string(requestBody)).Msg("Marshaled payload")
6969
span := d.instrumentor.StartSpan(ctx, "code.SubmitRequest")
7070
defer span.Finish()
7171

7272
// Encode the request body
73-
bodyBuffer, err := http2.EncodeIfNeeded(http.MethodPost, requestBody)
73+
bodyBuffer, err := http2.EncodeIfNeeded(http.MethodPost, requestBody, needsEncoding)
7474
if err != nil {
7575
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error encoding request body")
7676
return nil, err
@@ -82,7 +82,7 @@ func (d *DeepCodeLLMBindingImpl) submitRequest(ctx context.Context, url *url.URL
8282
return nil, err
8383
}
8484

85-
http2.AddDefaultHeaders(req, http2.NoRequestId, orgId, http.MethodPost)
85+
http2.AddDefaultHeaders(req, http2.NoRequestId, orgId, http.MethodPost, needsEncoding)
8686

8787
resp, err := d.httpClientFunc().Do(req) //nolint:bodyclose // this seems to be a false positive
8888
if err != nil {
@@ -152,7 +152,7 @@ func (d *DeepCodeLLMBindingImpl) runAutofix(ctx context.Context, options Autofix
152152
}
153153

154154
logger.Info().Msg("Started obtaining autofix Response")
155-
responseBody, err := d.submitRequest(span.Context(), endpoint, requestBody, options.CodeRequestContext.Org.PublicId)
155+
responseBody, err := d.submitRequest(span.Context(), endpoint, requestBody, options.CodeRequestContext.Org.PublicId, true)
156156
logger.Info().Msg("Finished obtaining autofix Response")
157157

158158
if err != nil {
@@ -228,7 +228,7 @@ func (d *DeepCodeLLMBindingImpl) submitAutofixFeedback(ctx context.Context, opti
228228
}
229229

230230
logger.Info().Msg("Started obtaining autofix Response")
231-
_, err = d.submitRequest(span.Context(), endpoint, requestBody, options.CodeRequestContext.Org.PublicId)
231+
_, err = d.submitRequest(span.Context(), endpoint, requestBody, options.CodeRequestContext.Org.PublicId, true)
232232
logger.Info().Msg("Finished obtaining autofix Response")
233233

234234
return err

llm/api_client_test.go

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,30 @@ func testLogger(t *testing.T) *zerolog.Logger {
264264
func TestAddDefaultHeadersWithExistingHeaders(t *testing.T) {
265265
req := &http.Request{Header: http.Header{"Existing-Header": {"existing-value"}}}
266266

267-
http2.AddDefaultHeaders(req, http2.NoRequestId, "", http.MethodGet)
267+
http2.AddDefaultHeaders(req, http2.NoRequestId, "", http.MethodPost, true)
268+
269+
cacheControl := req.Header.Get("Cache-Control")
270+
contentType := req.Header.Get("Content-Type")
271+
existingHeader := req.Header.Get("Existing-Header")
272+
273+
if cacheControl != "private, max-age=0, no-cache" {
274+
t.Errorf("Expected Cache-Control header to be 'private, max-age=0, no-cache', got %s", cacheControl)
275+
}
276+
277+
if contentType != "application/octet-stream" {
278+
t.Errorf("Expected Content-Type header to be 'application/json', got %s", contentType)
279+
}
280+
281+
if existingHeader != "existing-value" {
282+
t.Errorf("Expected Existing-Header to be 'existing-value', got %s", existingHeader)
283+
}
284+
}
285+
286+
// Test with existing headers
287+
func TestAddDefaultHeadersWithSkipEncodingEnabled(t *testing.T) {
288+
req := &http.Request{Header: http.Header{"Existing-Header": {"existing-value"}}}
289+
290+
http2.AddDefaultHeaders(req, http2.NoRequestId, "", http.MethodPost, false)
268291

269292
cacheControl := req.Header.Get("Cache-Control")
270293
contentType := req.Header.Get("Content-Type")
@@ -350,3 +373,153 @@ func TestAutofixRequestBody(t *testing.T) {
350373

351374
assert.Equal(t, expectedBody, body)
352375
}
376+
377+
func TestRunExplain_WithHeaderValidation(t *testing.T) {
378+
t.Run("vulnerability explanation with headers", func(t *testing.T) {
379+
ruleKey := "test-rule-key"
380+
derivation := "test-derivation"
381+
ruleMessage := "test-rule-message"
382+
383+
expectedResponse := Explanations{
384+
"explanation1": "This is the first explanation",
385+
"explanation2": "This is the second explanation",
386+
}
387+
388+
response := explainResponse{
389+
Status: completeStatus,
390+
Explanation: expectedResponse,
391+
}
392+
393+
responseBodyBytes, err := json.Marshal(response)
394+
require.NoError(t, err)
395+
396+
// Create a test server that validates headers
397+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
398+
// Verify headers
399+
assert.Equal(t, "private, max-age=0, no-cache", r.Header.Get("Cache-Control"))
400+
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
401+
assert.Equal(t, http.MethodPost, r.Method)
402+
403+
// Verify request body
404+
body, readErr := io.ReadAll(r.Body)
405+
require.NoError(t, readErr)
406+
407+
var requestData explainVulnerabilityRequest
408+
err = json.Unmarshal(body, &requestData)
409+
require.NoError(t, err)
410+
411+
assert.Equal(t, ruleKey, requestData.RuleId)
412+
assert.Equal(t, derivation, requestData.Derivation)
413+
assert.Equal(t, ruleMessage, requestData.RuleMessage)
414+
assert.Equal(t, SHORT, requestData.ExplanationLength)
415+
416+
// Send response
417+
w.WriteHeader(http.StatusOK)
418+
_, _ = w.Write(responseBodyBytes)
419+
}))
420+
defer server.Close()
421+
422+
// Parse server URL
423+
u, err := url.Parse(server.URL)
424+
require.NoError(t, err)
425+
426+
// Create options
427+
options := ExplainOptions{
428+
RuleKey: ruleKey,
429+
Derivation: derivation,
430+
RuleMessage: ruleMessage,
431+
Endpoint: u,
432+
}
433+
434+
// Create DeepCodeLLMBinding
435+
d := NewDeepcodeLLMBinding()
436+
437+
// Run the test
438+
ctx := t.Context()
439+
ctx = observability.GetContextWithTraceId(ctx, "test-trace-id")
440+
441+
result, err := d.runExplain(ctx, options)
442+
443+
// Verify results
444+
require.NoError(t, err)
445+
assert.Equal(t, expectedResponse, result)
446+
})
447+
448+
t.Run("fix explanation with base64 encoded diffs and headers", func(t *testing.T) {
449+
ruleKey := "test-rule-key"
450+
testDiffs := []string{
451+
"--- a/file.txt\n+++ b/file.txt\n@@ -1,1 +1,1 @@\n-old line\n+new line\n",
452+
}
453+
454+
expectedResponse := Explanations{
455+
"explanation1": "This explains the fix",
456+
}
457+
458+
response := explainResponse{
459+
Status: completeStatus,
460+
Explanation: expectedResponse,
461+
}
462+
463+
responseBodyBytes, err := json.Marshal(response)
464+
require.NoError(t, err)
465+
466+
// Create a test server that validates headers and base64 encoded diffs
467+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
468+
// Verify headers
469+
assert.Equal(t, "private, max-age=0, no-cache", r.Header.Get("Cache-Control"))
470+
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
471+
assert.Equal(t, http.MethodPost, r.Method)
472+
473+
// Verify request body
474+
body, readErr := io.ReadAll(r.Body)
475+
require.NoError(t, readErr)
476+
477+
var requestData explainFixRequest
478+
err = json.Unmarshal(body, &requestData)
479+
require.NoError(t, err)
480+
481+
assert.Equal(t, ruleKey, requestData.RuleId)
482+
assert.Equal(t, SHORT, requestData.ExplanationLength)
483+
484+
// Verify diffs are base64 encoded
485+
require.Len(t, requestData.Diffs, 1)
486+
487+
// Decode the base64 diff to verify it was encoded properly
488+
decodedDiff, decodeErr := base64.StdEncoding.DecodeString(requestData.Diffs[0])
489+
require.NoError(t, decodeErr)
490+
491+
// The prepareDiffs function strips --- and +++ headers and adds a newline
492+
expectedDecodedDiff := "@@ -1,1 +1,1 @@\n-old line\n+new line\n\n"
493+
assert.Equal(t, expectedDecodedDiff, string(decodedDiff))
494+
495+
// Send response
496+
w.WriteHeader(http.StatusOK)
497+
_, _ = w.Write(responseBodyBytes)
498+
}))
499+
defer server.Close()
500+
501+
// Parse server URL
502+
u, err := url.Parse(server.URL)
503+
require.NoError(t, err)
504+
505+
// Create options
506+
options := ExplainOptions{
507+
RuleKey: ruleKey,
508+
Diffs: testDiffs,
509+
Endpoint: u,
510+
}
511+
512+
// Create DeepCodeLLMBinding
513+
d := NewDeepcodeLLMBinding()
514+
515+
// Run the test
516+
ctx := t.Context()
517+
ctx = observability.GetContextWithTraceId(ctx, "test-trace-id")
518+
519+
result, err := d.runExplain(ctx, options)
520+
521+
// Verify results
522+
require.NoError(t, err)
523+
assert.Equal(t, expectedResponse, result)
524+
})
525+
}

0 commit comments

Comments
 (0)