@@ -264,7 +264,30 @@ func testLogger(t *testing.T) *zerolog.Logger {
264264func TestAddDefaultHeadersWithExistingHeaders (t * testing.T ) {
265265 req := & http.Request {Header : http.Header {"Existing-Header" : {"existing-value" }}}
266266
267- http2 .AddDefaultHeaders (req , http2 .NoRequestId , "" , http .MethodGet )
267+ http2 .AddDefaultHeaders (req , http2 .NoRequestId , "" , http .MethodPost , true )
268+
269+ cacheControl := req .Header .Get ("Cache-Control" )
270+ contentType := req .Header .Get ("Content-Type" )
271+ existingHeader := req .Header .Get ("Existing-Header" )
272+
273+ if cacheControl != "private, max-age=0, no-cache" {
274+ t .Errorf ("Expected Cache-Control header to be 'private, max-age=0, no-cache', got %s" , cacheControl )
275+ }
276+
277+ if contentType != "application/octet-stream" {
278+ t .Errorf ("Expected Content-Type header to be 'application/json', got %s" , contentType )
279+ }
280+
281+ if existingHeader != "existing-value" {
282+ t .Errorf ("Expected Existing-Header to be 'existing-value', got %s" , existingHeader )
283+ }
284+ }
285+
286+ // Test with existing headers
287+ func TestAddDefaultHeadersWithSkipEncodingEnabled (t * testing.T ) {
288+ req := & http.Request {Header : http.Header {"Existing-Header" : {"existing-value" }}}
289+
290+ http2 .AddDefaultHeaders (req , http2 .NoRequestId , "" , http .MethodPost , false )
268291
269292 cacheControl := req .Header .Get ("Cache-Control" )
270293 contentType := req .Header .Get ("Content-Type" )
@@ -350,3 +373,153 @@ func TestAutofixRequestBody(t *testing.T) {
350373
351374 assert .Equal (t , expectedBody , body )
352375}
376+
377+ func TestRunExplain_WithHeaderValidation (t * testing.T ) {
378+ t .Run ("vulnerability explanation with headers" , func (t * testing.T ) {
379+ ruleKey := "test-rule-key"
380+ derivation := "test-derivation"
381+ ruleMessage := "test-rule-message"
382+
383+ expectedResponse := Explanations {
384+ "explanation1" : "This is the first explanation" ,
385+ "explanation2" : "This is the second explanation" ,
386+ }
387+
388+ response := explainResponse {
389+ Status : completeStatus ,
390+ Explanation : expectedResponse ,
391+ }
392+
393+ responseBodyBytes , err := json .Marshal (response )
394+ require .NoError (t , err )
395+
396+ // Create a test server that validates headers
397+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
398+ // Verify headers
399+ assert .Equal (t , "private, max-age=0, no-cache" , r .Header .Get ("Cache-Control" ))
400+ assert .Equal (t , "application/json" , r .Header .Get ("Content-Type" ))
401+ assert .Equal (t , http .MethodPost , r .Method )
402+
403+ // Verify request body
404+ body , readErr := io .ReadAll (r .Body )
405+ require .NoError (t , readErr )
406+
407+ var requestData explainVulnerabilityRequest
408+ err = json .Unmarshal (body , & requestData )
409+ require .NoError (t , err )
410+
411+ assert .Equal (t , ruleKey , requestData .RuleId )
412+ assert .Equal (t , derivation , requestData .Derivation )
413+ assert .Equal (t , ruleMessage , requestData .RuleMessage )
414+ assert .Equal (t , SHORT , requestData .ExplanationLength )
415+
416+ // Send response
417+ w .WriteHeader (http .StatusOK )
418+ _ , _ = w .Write (responseBodyBytes )
419+ }))
420+ defer server .Close ()
421+
422+ // Parse server URL
423+ u , err := url .Parse (server .URL )
424+ require .NoError (t , err )
425+
426+ // Create options
427+ options := ExplainOptions {
428+ RuleKey : ruleKey ,
429+ Derivation : derivation ,
430+ RuleMessage : ruleMessage ,
431+ Endpoint : u ,
432+ }
433+
434+ // Create DeepCodeLLMBinding
435+ d := NewDeepcodeLLMBinding ()
436+
437+ // Run the test
438+ ctx := t .Context ()
439+ ctx = observability .GetContextWithTraceId (ctx , "test-trace-id" )
440+
441+ result , err := d .runExplain (ctx , options )
442+
443+ // Verify results
444+ require .NoError (t , err )
445+ assert .Equal (t , expectedResponse , result )
446+ })
447+
448+ t .Run ("fix explanation with base64 encoded diffs and headers" , func (t * testing.T ) {
449+ ruleKey := "test-rule-key"
450+ testDiffs := []string {
451+ "--- a/file.txt\n +++ b/file.txt\n @@ -1,1 +1,1 @@\n -old line\n +new line\n " ,
452+ }
453+
454+ expectedResponse := Explanations {
455+ "explanation1" : "This explains the fix" ,
456+ }
457+
458+ response := explainResponse {
459+ Status : completeStatus ,
460+ Explanation : expectedResponse ,
461+ }
462+
463+ responseBodyBytes , err := json .Marshal (response )
464+ require .NoError (t , err )
465+
466+ // Create a test server that validates headers and base64 encoded diffs
467+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
468+ // Verify headers
469+ assert .Equal (t , "private, max-age=0, no-cache" , r .Header .Get ("Cache-Control" ))
470+ assert .Equal (t , "application/json" , r .Header .Get ("Content-Type" ))
471+ assert .Equal (t , http .MethodPost , r .Method )
472+
473+ // Verify request body
474+ body , readErr := io .ReadAll (r .Body )
475+ require .NoError (t , readErr )
476+
477+ var requestData explainFixRequest
478+ err = json .Unmarshal (body , & requestData )
479+ require .NoError (t , err )
480+
481+ assert .Equal (t , ruleKey , requestData .RuleId )
482+ assert .Equal (t , SHORT , requestData .ExplanationLength )
483+
484+ // Verify diffs are base64 encoded
485+ require .Len (t , requestData .Diffs , 1 )
486+
487+ // Decode the base64 diff to verify it was encoded properly
488+ decodedDiff , decodeErr := base64 .StdEncoding .DecodeString (requestData .Diffs [0 ])
489+ require .NoError (t , decodeErr )
490+
491+ // The prepareDiffs function strips --- and +++ headers and adds a newline
492+ expectedDecodedDiff := "@@ -1,1 +1,1 @@\n -old line\n +new line\n \n "
493+ assert .Equal (t , expectedDecodedDiff , string (decodedDiff ))
494+
495+ // Send response
496+ w .WriteHeader (http .StatusOK )
497+ _ , _ = w .Write (responseBodyBytes )
498+ }))
499+ defer server .Close ()
500+
501+ // Parse server URL
502+ u , err := url .Parse (server .URL )
503+ require .NoError (t , err )
504+
505+ // Create options
506+ options := ExplainOptions {
507+ RuleKey : ruleKey ,
508+ Diffs : testDiffs ,
509+ Endpoint : u ,
510+ }
511+
512+ // Create DeepCodeLLMBinding
513+ d := NewDeepcodeLLMBinding ()
514+
515+ // Run the test
516+ ctx := t .Context ()
517+ ctx = observability .GetContextWithTraceId (ctx , "test-trace-id" )
518+
519+ result , err := d .runExplain (ctx , options )
520+
521+ // Verify results
522+ require .NoError (t , err )
523+ assert .Equal (t , expectedResponse , result )
524+ })
525+ }
0 commit comments