@@ -341,6 +341,255 @@ async def __aexit__(self, *args):
341341 assert connection is mock_connection
342342
343343
344+ @pytest .mark .asyncio
345+ async def test_generate_content_async_with_custom_headers (
346+ gemini_llm , llm_request , generate_content_response
347+ ):
348+ """Test that tracking headers are updated when custom headers are provided."""
349+ # Add custom headers to the request config
350+ custom_headers = {"custom-header" : "custom-value" }
351+ for key in gemini_llm ._tracking_headers :
352+ custom_headers [key ] = "custom " + gemini_llm ._tracking_headers [key ]
353+ llm_request .config .http_options = types .HttpOptions (headers = custom_headers )
354+
355+ with mock .patch .object (gemini_llm , "api_client" ) as mock_client :
356+ # Create a mock coroutine that returns the generate_content_response
357+ async def mock_coro ():
358+ return generate_content_response
359+
360+ mock_client .aio .models .generate_content .return_value = mock_coro ()
361+
362+ responses = [
363+ resp
364+ async for resp in gemini_llm .generate_content_async (
365+ llm_request , stream = False
366+ )
367+ ]
368+
369+ # Verify that the config passed to generate_content contains merged headers
370+ mock_client .aio .models .generate_content .assert_called_once ()
371+ call_args = mock_client .aio .models .generate_content .call_args
372+ config_arg = call_args .kwargs ["config" ]
373+
374+ for key , value in config_arg .http_options .headers .items ():
375+ if key in gemini_llm ._tracking_headers :
376+ assert value == gemini_llm ._tracking_headers [key ]
377+ else :
378+ assert value == custom_headers [key ]
379+
380+ assert len (responses ) == 1
381+ assert isinstance (responses [0 ], LlmResponse )
382+
383+
384+ @pytest .mark .asyncio
385+ async def test_generate_content_async_stream_with_custom_headers (
386+ gemini_llm , llm_request
387+ ):
388+ """Test that tracking headers are updated when custom headers are provided in streaming mode."""
389+ # Add custom headers to the request config
390+ custom_headers = {"custom-header" : "custom-value" }
391+ llm_request .config .http_options = types .HttpOptions (headers = custom_headers )
392+
393+ with mock .patch .object (gemini_llm , "api_client" ) as mock_client :
394+ # Create mock stream responses
395+ class MockAsyncIterator :
396+
397+ def __init__ (self , seq ):
398+ self .iter = iter (seq )
399+
400+ def __aiter__ (self ):
401+ return self
402+
403+ async def __anext__ (self ):
404+ try :
405+ return next (self .iter )
406+ except StopIteration :
407+ raise StopAsyncIteration
408+
409+ mock_responses = [
410+ types .GenerateContentResponse (
411+ candidates = [
412+ types .Candidate (
413+ content = Content (
414+ role = "model" , parts = [Part .from_text (text = "Hello" )]
415+ ),
416+ finish_reason = types .FinishReason .STOP ,
417+ )
418+ ]
419+ )
420+ ]
421+
422+ async def mock_coro ():
423+ return MockAsyncIterator (mock_responses )
424+
425+ mock_client .aio .models .generate_content_stream .return_value = mock_coro ()
426+
427+ responses = [
428+ resp
429+ async for resp in gemini_llm .generate_content_async (
430+ llm_request , stream = True
431+ )
432+ ]
433+
434+ # Verify that the config passed to generate_content_stream contains merged headers
435+ mock_client .aio .models .generate_content_stream .assert_called_once ()
436+ call_args = mock_client .aio .models .generate_content_stream .call_args
437+ config_arg = call_args .kwargs ["config" ]
438+
439+ expected_headers = custom_headers .copy ()
440+ expected_headers .update (gemini_llm ._tracking_headers )
441+ assert config_arg .http_options .headers == expected_headers
442+
443+ assert len (responses ) == 2
444+
445+
446+ @pytest .mark .asyncio
447+ async def test_generate_content_async_without_custom_headers (
448+ gemini_llm , llm_request , generate_content_response
449+ ):
450+ """Test that tracking headers are not modified when no custom headers exist."""
451+ # Ensure no http_options exist initially
452+ llm_request .config .http_options = None
453+
454+ with mock .patch .object (gemini_llm , "api_client" ) as mock_client :
455+
456+ async def mock_coro ():
457+ return generate_content_response
458+
459+ mock_client .aio .models .generate_content .return_value = mock_coro ()
460+
461+ responses = [
462+ resp
463+ async for resp in gemini_llm .generate_content_async (
464+ llm_request , stream = False
465+ )
466+ ]
467+
468+ # Verify that the config passed to generate_content has no http_options
469+ mock_client .aio .models .generate_content .assert_called_once ()
470+ call_args = mock_client .aio .models .generate_content .call_args
471+ config_arg = call_args .kwargs ["config" ]
472+ assert config_arg .http_options is None
473+
474+ assert len (responses ) == 1
475+
476+
477+ def test_live_api_version_vertex_ai (gemini_llm ):
478+ """Test that _live_api_version returns 'v1beta1' for Vertex AI backend."""
479+ with mock .patch .object (
480+ gemini_llm , "_api_backend" , GoogleLLMVariant .VERTEX_AI
481+ ):
482+ assert gemini_llm ._live_api_version == "v1beta1"
483+
484+
485+ def test_live_api_version_gemini_api (gemini_llm ):
486+ """Test that _live_api_version returns 'v1alpha' for Gemini API backend."""
487+ with mock .patch .object (
488+ gemini_llm , "_api_backend" , GoogleLLMVariant .GEMINI_API
489+ ):
490+ assert gemini_llm ._live_api_version == "v1alpha"
491+
492+
493+ def test_live_api_client_properties (gemini_llm ):
494+ """Test that _live_api_client is properly configured with tracking headers and API version."""
495+ with mock .patch .object (
496+ gemini_llm , "_api_backend" , GoogleLLMVariant .VERTEX_AI
497+ ):
498+ client = gemini_llm ._live_api_client
499+
500+ # Verify that the client has the correct headers and API version
501+ http_options = client ._api_client ._http_options
502+ assert http_options .api_version == "v1beta1"
503+
504+ # Check that tracking headers are included
505+ tracking_headers = gemini_llm ._tracking_headers
506+ for key , value in tracking_headers .items ():
507+ assert key in http_options .headers
508+ assert value in http_options .headers [key ]
509+
510+
511+ @pytest .mark .asyncio
512+ async def test_connect_with_custom_headers (gemini_llm , llm_request ):
513+ """Test that connect method updates tracking headers and API version when custom headers are provided."""
514+ # Setup request with live connect config and custom headers
515+ custom_headers = {"custom-live-header" : "live-value" }
516+ llm_request .live_connect_config = types .LiveConnectConfig (
517+ http_options = types .HttpOptions (headers = custom_headers )
518+ )
519+
520+ mock_live_session = mock .AsyncMock ()
521+
522+ # Mock the _live_api_client to return a mock client
523+ with mock .patch .object (gemini_llm , "_live_api_client" ) as mock_live_client :
524+ # Create a mock context manager
525+ class MockLiveConnect :
526+
527+ async def __aenter__ (self ):
528+ return mock_live_session
529+
530+ async def __aexit__ (self , * args ):
531+ pass
532+
533+ mock_live_client .aio .live .connect .return_value = MockLiveConnect ()
534+
535+ async with gemini_llm .connect (llm_request ) as connection :
536+ # Verify that the connect method was called with the right config
537+ mock_live_client .aio .live .connect .assert_called_once ()
538+ call_args = mock_live_client .aio .live .connect .call_args
539+ config_arg = call_args .kwargs ["config" ]
540+
541+ # Verify that tracking headers were merged with custom headers
542+ expected_headers = custom_headers .copy ()
543+ expected_headers .update (gemini_llm ._tracking_headers )
544+ assert config_arg .http_options .headers == expected_headers
545+
546+ # Verify that API version was set
547+ assert config_arg .http_options .api_version == gemini_llm ._live_api_version
548+
549+ # Verify that system instruction and tools were set
550+ assert config_arg .system_instruction is not None
551+ assert config_arg .tools == llm_request .config .tools
552+
553+ # Verify connection is properly wrapped
554+ assert isinstance (connection , GeminiLlmConnection )
555+
556+
557+ @pytest .mark .asyncio
558+ async def test_connect_without_custom_headers (gemini_llm , llm_request ):
559+ """Test that connect method works properly when no custom headers are provided."""
560+ # Setup request with live connect config but no custom headers
561+ llm_request .live_connect_config = types .LiveConnectConfig ()
562+
563+ mock_live_session = mock .AsyncMock ()
564+
565+ with mock .patch .object (gemini_llm , "_live_api_client" ) as mock_live_client :
566+
567+ class MockLiveConnect :
568+
569+ async def __aenter__ (self ):
570+ return mock_live_session
571+
572+ async def __aexit__ (self , * args ):
573+ pass
574+
575+ mock_live_client .aio .live .connect .return_value = MockLiveConnect ()
576+
577+ async with gemini_llm .connect (llm_request ) as connection :
578+ # Verify that the connect method was called with the right config
579+ mock_live_client .aio .live .connect .assert_called_once ()
580+ call_args = mock_live_client .aio .live .connect .call_args
581+ config_arg = call_args .kwargs ["config" ]
582+
583+ # Verify that http_options remains None since no custom headers were provided
584+ assert config_arg .http_options is None
585+
586+ # Verify that system instruction and tools were still set
587+ assert config_arg .system_instruction is not None
588+ assert config_arg .tools == llm_request .config .tools
589+
590+ assert isinstance (connection , GeminiLlmConnection )
591+
592+
344593@pytest .mark .parametrize (
345594 (
346595 "api_backend, "
0 commit comments