Mirror of https://github.com/open-telemetry/opentelemetry-python-contrib.git, synced 2025-07-28 12:43:39 +08:00
Add support for async and streaming responses in the Google GenAI instrumentation (#3298)
* Begin instrumentation of GenAI SDK.
* Snapshot current state.
* Created minimal tests and got first test to pass.
* Added test for span attributes.
* Ensure that token counts work.
* Add more tests.
* Make it easy to turn off instrumentation for streaming and async to allow for rapid iteration.
* Add licenses and fill out main README.rst.
* Add a changelog file.
* Fill out 'requirements.txt' and 'README.rst' for the manual instrumentation example.
* Add missing exporter dependency for the manual instrumentation example.
* Fill out rest of the zero-code example.
* Add minimal tests for async, streaming cases.
* Update sync test to use indirection on top of 'client.models.generate_content' to simplify test reuse.
* Fix ruff check issues.
* Add subproject to top-level project build mechanism.
* Simplify invocation of pylint.
* Fix 'make test' command and lint issues.
* Add '.dev' suffix to version per feedback on pull request #3256
* Fix README.rst files for the examples.
* Add specific versions for the examples.
* Revamp 'make test' to not require local 'tox.ini' configuration.
* Extend separators per review comment.
  Co-authored-by: Riccardo Magliocchetti <riccardo.magliocchetti@gmail.com>
* Fix version conflict caused by non-hermetic requirements.
* Fix typo on the comment line.
* Add test for the use of the 'vertex_ai' system, and improve how this system is determined.
* Factor out testing logic to enable sharing with the async code.
* Addressed minor lint issues.
* Make it clearer that nonstreaming_base is a helper module that is not invoked directly.
* Integrate feedback from related pull request #3268.
* Update workflows with 'tox -e generate-workflows'.
* Improve data model and add some rudimentary type checking.
* Accept only 'true' for a true value to align with other code.
* Update the scope name used.
* Add **kwargs to patched methods to prevent future breakage due to the addition of future keyword arguments.
* Remove redundant list conversion in call to "sorted".
  Co-authored-by: Aaron Abbott <aaronabbott@google.com>
* Reformat with 'tox -e ruff'.
* Fix failing lint workflow.
* Fix failing lint workflow.
* Exclude Google GenAI instrumentation from the bootstrap code for now.
* Minor improvements to the tooling shell files.
* Fix typo flagged by codespell spellchecker.
* Increase alignment with broader repo practices.
* Add more TODOs and documentation to clarify the intended work scope.
* Remove unneeded accessor from OTelWrapper.
* Add more comments to the tests.
* Reformat with ruff.
* Change 'desireable' to 'desirable' per codespell spellchecker.
* Make tests pass without pythonpath
* Fix new lint errors showing up after change
* Revert "Fix new lint errors showing up after change"
  This reverts commit 567adc62a706035ad8ac5e29316c7a6f8d4c7909.
  pylint ignore instead
* Add TODO item required/requested from code review.
  Co-authored-by: Aaron Abbott <aaronabbott@google.com>
* Simplify changelog per PR feedback.
* Remove square brackets from model name in span name per PR feedback.
* Checkpoint current state.
* Misc test cleanup. Now that scripts are invoked solely through pytest via tox, remove main functions and hash bang lines.
* Improve quality of event logging.
* Implement streaming support in RequestsMocker, get tests passing again.
* Add test with multiple responses.
* Remove support for async and streaming from TODOs, since this is now addressed.
* Increase testing coverage for streaming.
* Reformat with ruff.
* Add minor version bump with changelog.
* Change TODOs to bulleted list.
* Update per PR feedback
  Co-authored-by: Aaron Abbott <aaronabbott@google.com>
* Restructure streaming async logic to begin execution earlier.
* Reformat with ruff.
* Disable pylint check for catching broad exception. Should be allowed given exception is re-raised.
* Simplify async streaming solution per PR comment.

---------

Co-authored-by: Riccardo Magliocchetti <riccardo.magliocchetti@gmail.com>
Co-authored-by: Aaron Abbott <aaronabbott@google.com>
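For context, below is a minimal usage sketch of what this change enables. The instrumentor entry point (`GoogleGenAiSdkInstrumentor` and its module path) is an assumption not shown in this diff; the client calls mirror the ones exercised by the tests in this commit.

```python
# Hedged sketch only: the instrumentor class/module name is assumed, not
# confirmed by this diff. The client calls match those used in the tests.
import asyncio

from google import genai
from opentelemetry.instrumentation.google_genai import GoogleGenAiSdkInstrumentor

GoogleGenAiSdkInstrumentor().instrument()
client = genai.Client()

# Sync streaming: the generate_content span now stays open until the
# stream has been fully consumed.
for chunk in client.models.generate_content_stream(
    model="gemini-2.0-flash", contents="Does this work?"
):
    print(chunk.text)


# Async calls (non-streaming and streaming) are traced as well.
async def main():
    response = await client.aio.models.generate_content(
        model="gemini-2.0-flash", contents="Does this work?"
    )
    print(response.text)


asyncio.run(main())
```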
@@ -7,5 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Add support for async and streaming.
  ([#3298](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3298))

Create an initial version of Open Telemetry instrumentation for github.com/googleapis/python-genai.
  ([#3256](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3256))
@@ -4,18 +4,17 @@

Here are some TODO items required to achieve stability for this package:

1. Add support for streaming interfaces
2. Add support for async interfaces
3. Add more span-level attributes for request configuration
4. Add more span-level attributes for response information
5. Verify and correct formatting of events:
- Add more span-level attributes for request configuration
- Add more span-level attributes for response information
- Verify and correct formatting of events:
    - Including the 'role' field for message events
    - Including tool invocation information
6. Emit events for safety ratings when they block responses
7. Additional cleanup/improvement tasks such as:
- Emit events for safety ratings when they block responses
- Additional cleanup/improvement tasks such as:
    - Adoption of 'wrapt' instead of 'functools.wraps'
    - Bolstering test coverage
8. Migrate tests to use VCR.py
- Migrate tests to use VCR.py

## Future

Beyond the above TODOs, it would also be desirable to extend the
@@ -45,15 +45,11 @@ from .otel_wrapper import OTelWrapper
_logger = logging.getLogger(__name__)


# Constant used for the value of 'gen_ai.operation.name'.
_GENERATE_CONTENT_OP_NAME = "generate_content"

# Constant used to make the absence of content more understandable.
_CONTENT_ELIDED = "<elided>"

# Enable these after these cases are fully vetted and tested
_INSTRUMENT_STREAMING = False
_INSTRUMENT_ASYNC = False
# Constant used for the value of 'gen_ai.operation.name'.
_GENERATE_CONTENT_OP_NAME = "generate_content"


class _MethodsSnapshot:
@@ -220,7 +216,9 @@ class _GenerateContentInstrumentationHelper:
        self._response_index = 0
        self._candidate_index = 0

    def start_span_as_current_span(self, model_name, function_name):
    def start_span_as_current_span(
        self, model_name, function_name, end_on_exit=True
    ):
        return self._otel_wrapper.start_as_current_span(
            f"{_GENERATE_CONTENT_OP_NAME} {model_name}",
            start_time=self._start_time,
@@ -230,6 +228,7 @@ class _GenerateContentInstrumentationHelper:
                gen_ai_attributes.GEN_AI_REQUEST_MODEL: self._genai_request_model,
                gen_ai_attributes.GEN_AI_OPERATION_NAME: _GENERATE_CONTENT_OP_NAME,
            },
            end_on_exit=end_on_exit,
        )

    def process_request(
@@ -543,9 +542,6 @@ def _create_instrumented_generate_content_stream(
    snapshot: _MethodsSnapshot, otel_wrapper: OTelWrapper
):
    wrapped_func = snapshot.generate_content_stream
    if not _INSTRUMENT_STREAMING:
        # TODO: remove once this case has been fully tested
        return wrapped_func

    @functools.wraps(wrapped_func)
    def instrumented_generate_content_stream(
@@ -586,9 +582,6 @@ def _create_instrumented_async_generate_content(
    snapshot: _MethodsSnapshot, otel_wrapper: OTelWrapper
):
    wrapped_func = snapshot.async_generate_content
    if not _INSTRUMENT_ASYNC:
        # TODO: remove once this case has been fully tested
        return wrapped_func

    @functools.wraps(wrapped_func)
    async def instrumented_generate_content(
@@ -630,9 +623,6 @@ def _create_instrumented_async_generate_content_stream(  # pyright: ignore
    snapshot: _MethodsSnapshot, otel_wrapper: OTelWrapper
):
    wrapped_func = snapshot.async_generate_content_stream
    if not _INSTRUMENT_ASYNC or not _INSTRUMENT_STREAMING:
        # TODO: remove once this case has been fully tested
        return wrapped_func

    @functools.wraps(wrapped_func)
    async def instrumented_generate_content_stream(
@@ -647,25 +637,39 @@ def _create_instrumented_async_generate_content_stream(  # pyright: ignore
            self, otel_wrapper, model
        )
        with helper.start_span_as_current_span(
            model, "google.genai.AsyncModels.generate_content_stream"
        ):
            model,
            "google.genai.AsyncModels.generate_content_stream",
            end_on_exit=False,
        ) as span:
            helper.process_request(contents, config)
            try:
                async for response in await wrapped_func(
                response_async_generator = await wrapped_func(
                    self,
                    model=model,
                    contents=contents,
                    config=config,
                    **kwargs,
                ):  # pyright: ignore
                )
            except Exception as error:  # pylint: disable=broad-exception-caught
                helper.process_error(error)
                helper.finalize_processing()
                with trace.use_span(span, end_on_exit=True):
                    raise

            async def _response_async_generator_wrapper():
                with trace.use_span(span, end_on_exit=True):
                    try:
                        async for response in response_async_generator:
                            helper.process_response(response)
                            yield response  # pyright: ignore
                            yield response
                    except Exception as error:
                        helper.process_error(error)
                        raise
                    finally:
                        helper.finalize_processing()

            return _response_async_generator_wrapper()

    return instrumented_generate_content_stream
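The hunk above is the core of the async streaming change: the span is opened with end_on_exit=False so that leaving the `with` block does not end it, and the generator wrapper re-enters the span via trace.use_span(..., end_on_exit=True), so the span (and its finalization) completes only once the caller finishes consuming the stream or the initial request fails. A stripped-down sketch of that span-lifetime pattern, with illustrative names only:

```python
# Illustrative sketch of the pattern used above; "traced_stream" and
# "make_stream" are hypothetical names, not part of the instrumentation.
from opentelemetry import trace

tracer = trace.get_tracer(__name__)


def traced_stream(make_stream):
    # Start the span, but do not end it when the "with" block exits.
    with tracer.start_as_current_span(
        "generate_content_stream", end_on_exit=False
    ) as span:
        stream = make_stream()  # the request itself runs inside the span

        def _wrapper():
            # Re-enter the span; end_on_exit=True ends it once the stream
            # is exhausted (or the consumer raises).
            with trace.use_span(span, end_on_exit=True):
                for item in stream:
                    yield item

        return _wrapper()
```

Restructuring the code this way is what lets the request begin executing before the caller starts iterating, while still attributing every streamed chunk to the same span.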
@@ -17,4 +17,4 @@
# This version should stay below "1.0" until the fundamentals
# in "TODOS.md" have been addressed. Please revisit the TODOs
# listed there before bumping to a stable version.
__version__ = "0.0.1.dev"
__version__ = "0.0.2.dev"
@@ -179,6 +179,16 @@ class OTelMocker:
                return event
        return None

    def get_events_named(self, event_name):
        result = []
        for event in self.get_finished_logs():
            event_name_attr = event.attributes.get("event.name")
            if event_name_attr is None:
                continue
            if event_name_attr == event_name:
                result.append(event)
        return result

    def assert_has_event_named(self, name):
        event = self.get_event_named(name)
        finished_logs = self.get_finished_logs()
@@ -37,6 +37,7 @@ import functools
import http.client
import io
import json
from typing import Optional

import requests
import requests.sessions
@@ -81,7 +82,7 @@ class RequestsCall:


def _return_error_status(
    args: RequestsCallArgs, status_code: int, reason: str = None
    args: RequestsCallArgs, status_code: int, reason: Optional[str] = None
):
    result = requests.Response()
    result.url = args.request.url
@@ -123,6 +124,35 @@ def _to_response_generator(response):
    raise ValueError(f"Unsupported response type: {type(response)}")


def _to_stream_response_generator(response_generators):
    if len(response_generators) == 1:
        return response_generators[0]

    def combined_generator(args):
        first_response = response_generators[0](args)
        if first_response.status_code != 200:
            return first_response
        result = requests.Response()
        result.status_code = 200
        result.headers["content-type"] = "application/json"
        result.encoding = "utf-8"
        result.headers["transfer-encoding"] = "chunked"
        contents = []
        for generator in response_generators:
            response = generator(args)
            if response.status_code != 200:
                continue
            response_json = response.json()
            response_json_str = json.dumps(response_json)
            contents.append(f"data: {response_json_str}")
        contents_str = "\r\n".join(contents)
        full_contents = f"{contents_str}\r\n\r\n"
        result.raw = io.BytesIO(full_contents.encode())
        return result

    return combined_generator


class RequestsMocker:
    def __init__(self):
        self._original_send = requests.sessions.Session.send
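For illustration, if two mocked responses with JSON bodies {"a": 1} and {"b": 2} were registered, combined_generator above would produce a single chunked body along these lines (a hypothetical example, not part of the diff):

```python
# Each mocked JSON response becomes one "data: ..." line; the lines are joined
# with CRLF and the body is terminated by a blank line.
expected_body = 'data: {"a": 1}\r\ndata: {"b": 2}\r\n\r\n'
```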
@@ -159,6 +189,38 @@ class RequestsMocker:
        session: requests.sessions.Session,
        request: requests.PreparedRequest,
        **kwargs,
    ):
        stream = kwargs.get("stream", False)
        if not stream:
            return self._do_send_non_streaming(session, request, **kwargs)
        return self._do_send_streaming(session, request, **kwargs)

    def _do_send_streaming(
        self,
        session: requests.sessions.Session,
        request: requests.PreparedRequest,
        **kwargs,
    ):
        args = RequestsCallArgs(session, request, **kwargs)
        response_generators = []
        for matcher, response_generator in self._handlers:
            if matcher is None:
                response_generators.append(response_generator)
            elif matcher(args):
                response_generators.append(response_generator)
        if not response_generators:
            response_generators.append(_return_404)
        response_generator = _to_stream_response_generator(response_generators)
        call = RequestsCall(args, response_generator)
        result = call.response
        self._calls.append(call)
        return result

    def _do_send_non_streaming(
        self,
        session: requests.sessions.Session,
        request: requests.PreparedRequest,
        **kwargs,
    ):
        args = RequestsCallArgs(session, request, **kwargs)
        response_generator = self._lookup_response_generator(args)
@@ -17,31 +17,7 @@ import os
import unittest

from ..common.base import TestCase


def create_valid_response(
    response_text="The model response", input_tokens=10, output_tokens=20
):
    return {
        "modelVersion": "gemini-2.0-flash-test123",
        "usageMetadata": {
            "promptTokenCount": input_tokens,
            "candidatesTokenCount": output_tokens,
            "totalTokenCount": input_tokens + output_tokens,
        },
        "candidates": [
            {
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "text": response_text,
                        }
                    ],
                }
            }
        ],
    }
from .util import create_valid_response


class NonStreamingTestCase(TestCase):
@@ -56,22 +32,12 @@ class NonStreamingTestCase(TestCase):
    def generate_content(self, *args, **kwargs):
        raise NotImplementedError("Must implement 'generate_content'.")

    @property
    def expected_function_name(self):
        raise NotImplementedError("Must implement 'expected_function_name'.")

    def configure_valid_response(
        self,
        response_text="The model_response",
        input_tokens=10,
        output_tokens=20,
    ):
        self.requests.add_response(
            create_valid_response(
                response_text=response_text,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )
        )
    def configure_valid_response(self, *args, **kwargs):
        self.requests.add_response(create_valid_response(*args, **kwargs))

    def test_instrumentation_does_not_break_core_functionality(self):
        self.configure_valid_response(response_text="Yep, it works!")
@@ -0,0 +1,76 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from ..common.base import TestCase
from .util import create_valid_response


class StreamingTestCase(TestCase):
    # The "setUp" function is defined by "unittest.TestCase" and thus
    # this name must be used. Uncertain why pylint doesn't seem to
    # recognize that this is a unit test class for which this is inherited.
    def setUp(self):  # pylint: disable=invalid-name
        super().setUp()
        if self.__class__ == StreamingTestCase:
            raise unittest.SkipTest("Skipping testcase base.")

    def generate_content(self, *args, **kwargs):
        raise NotImplementedError("Must implement 'generate_content'.")

    @property
    def expected_function_name(self):
        raise NotImplementedError("Must implement 'expected_function_name'.")

    def configure_valid_response(self, *args, **kwargs):
        self.requests.add_response(create_valid_response(*args, **kwargs))

    def test_instrumentation_does_not_break_core_functionality(self):
        self.configure_valid_response(response_text="Yep, it works!")
        responses = self.generate_content(
            model="gemini-2.0-flash", contents="Does this work?"
        )
        self.assertEqual(len(responses), 1)
        response = responses[0]
        self.assertEqual(response.text, "Yep, it works!")

    def test_handles_multiple_responses(self):
        self.configure_valid_response(response_text="First response")
        self.configure_valid_response(response_text="Second response")
        responses = self.generate_content(
            model="gemini-2.0-flash", contents="Does this work?"
        )
        self.assertEqual(len(responses), 2)
        self.assertEqual(responses[0].text, "First response")
        self.assertEqual(responses[1].text, "Second response")
        choice_events = self.otel.get_events_named("gen_ai.choice")
        self.assertEqual(len(choice_events), 2)

    def test_includes_token_counts_in_span_aggregated_from_responses(self):
        # Configure multiple responses whose input/output tokens should be
        # accumulated together when summarizing the end-to-end request.
        #
        # Input: 1 + 3 + 5 = 9
        # Output: 2 + 4 + 6 = 12
        self.configure_valid_response(input_tokens=1, output_tokens=2)
        self.configure_valid_response(input_tokens=3, output_tokens=4)
        self.configure_valid_response(input_tokens=5, output_tokens=6)

        self.generate_content(model="gemini-2.0-flash", contents="Some input")

        self.otel.assert_has_span_named("generate_content gemini-2.0-flash")
        span = self.otel.get_span_named("generate_content gemini-2.0-flash")
        self.assertEqual(span.attributes["gen_ai.usage.input_tokens"], 9)
        self.assertEqual(span.attributes["gen_ai.usage.output_tokens"], 12)
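The aggregation test above relies on the instrumentation summing each streamed response's usage metadata into the span's token-count attributes. Roughly, under the response shape used by the test fixtures (illustrative only; the helper's internal fields may differ):

```python
# Illustrative accumulation over the three mocked responses configured above;
# the real helper tracks these counts as the stream is processed.
responses = [
    {"usageMetadata": {"promptTokenCount": 1, "candidatesTokenCount": 2}},
    {"usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 4}},
    {"usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 6}},
]
input_tokens = sum(r["usageMetadata"]["promptTokenCount"] for r in responses)  # 9
output_tokens = sum(r["usageMetadata"]["candidatesTokenCount"] for r in responses)  # 12
# The test asserts these totals via the span attributes
# "gen_ai.usage.input_tokens" and "gen_ai.usage.output_tokens".
```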
@@ -12,65 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.


# TODO: Once the async non-streaming case has been fully implemented,
# reimplement this in terms of "nonstreaming_base.py".

import asyncio

from ..common.base import TestCase
from .nonstreaming_base import NonStreamingTestCase


def create_valid_response(
    response_text="The model response", input_tokens=10, output_tokens=20
):
    return {
        "modelVersion": "gemini-2.0-flash-test123",
        "usageMetadata": {
            "promptTokenCount": input_tokens,
            "candidatesTokenCount": output_tokens,
            "totalTokenCount": input_tokens + output_tokens,
        },
        "candidates": [
            {
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "text": response_text,
                        }
                    ],
                }
            }
        ],
    }


# Temporary test fixture just to ensure that the in-progress work to
# implement this case doesn't break the original code.
class TestGenerateContentAsyncNonstreaming(TestCase):
    def configure_valid_response(
        self,
        response_text="The model_response",
        input_tokens=10,
        output_tokens=20,
    ):
        self.requests.add_response(
            create_valid_response(
                response_text=response_text,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )
        )

class TestGenerateContentAsyncNonstreaming(NonStreamingTestCase):
    def generate_content(self, *args, **kwargs):
        return asyncio.run(
            self.client.aio.models.generate_content(*args, **kwargs)  # pylint: disable=missing-kwoa
        )

    def test_async_generate_content_not_broken_by_instrumentation(self):
        self.configure_valid_response(response_text="Yep, it works!")
        response = self.generate_content(
            model="gemini-2.0-flash", contents="Does this work?"
        )
        self.assertEqual(response.text, "Yep, it works!")
    @property
    def expected_function_name(self):
        return "google.genai.AsyncModels.generate_content"
@@ -12,58 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: once the async streaming case has been implemented, we should have
# two different tests here that inherit from "streaming_base" and "nonstreaming_base",
# covering the cases of one response and multiple streaming responses.

import asyncio

from ..common.base import TestCase
from .nonstreaming_base import NonStreamingTestCase
from .streaming_base import StreamingTestCase


def create_valid_response(
    response_text="The model response", input_tokens=10, output_tokens=20
):
    return {
        "modelVersion": "gemini-2.0-flash-test123",
        "usageMetadata": {
            "promptTokenCount": input_tokens,
            "candidatesTokenCount": output_tokens,
            "totalTokenCount": input_tokens + output_tokens,
        },
        "candidates": [
            {
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "text": response_text,
                        }
                    ],
                }
            }
        ],
    }
class AsyncStreamingMixin:
    @property
    def expected_function_name(self):
        return "google.genai.AsyncModels.generate_content_stream"


# Temporary test fixture just to ensure that the in-progress work to
# implement this case doesn't break the original code.
class TestGenerateContentAsyncStreaming(TestCase):
    def configure_valid_response(
        self,
        response_text="The model_response",
        input_tokens=10,
        output_tokens=20,
    ):
        self.requests.add_response(
            create_valid_response(
                response_text=response_text,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )
        )

    async def _generate_content_helper(self, *args, **kwargs):
    async def _generate_content_stream_helper(self, *args, **kwargs):
        result = []
        async for (
            response
@@ -73,13 +33,23 @@ class TestGenerateContentAsyncStreaming(TestCase):
            result.append(response)
        return result

    def generate_content(self, *args, **kwargs):
        return asyncio.run(self._generate_content_helper(*args, **kwargs))

    def test_async_generate_content_not_broken_by_instrumentation(self):
        self.configure_valid_response(response_text="Yep, it works!")
        responses = self.generate_content(
            model="gemini-2.0-flash", contents="Does this work?"
    def generate_content_stream(self, *args, **kwargs):
        return asyncio.run(
            self._generate_content_stream_helper(*args, **kwargs)
        )


class TestGenerateContentAsyncStreamingWithSingleResult(
    AsyncStreamingMixin, NonStreamingTestCase
):
    def generate_content(self, *args, **kwargs):
        responses = self.generate_content_stream(*args, **kwargs)
        self.assertEqual(len(responses), 1)
        self.assertEqual(responses[0].text, "Yep, it works!")
        return responses[0]


class TestGenerateContentAsyncStreamingWithStreamedResults(
    AsyncStreamingMixin, StreamingTestCase
):
    def generate_content(self, *args, **kwargs):
        return self.generate_content_stream(*args, **kwargs)
@@ -12,57 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: once the async streaming case has been implemented, we should have
# two different tests here that inherit from "streaming_base" and "nonstreaming_base",
# covering the cases of one response and multiple streaming responses.

from .nonstreaming_base import NonStreamingTestCase
from .streaming_base import StreamingTestCase


from ..common.base import TestCase
class StreamingMixin:
    @property
    def expected_function_name(self):
        return "google.genai.Models.generate_content_stream"


def create_valid_response(
    response_text="The model response", input_tokens=10, output_tokens=20
):
    return {
        "modelVersion": "gemini-2.0-flash-test123",
        "usageMetadata": {
            "promptTokenCount": input_tokens,
            "candidatesTokenCount": output_tokens,
            "totalTokenCount": input_tokens + output_tokens,
        },
        "candidates": [
            {
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "text": response_text,
                        }
                    ],
                }
            }
        ],
    }


# Temporary test fixture just to ensure that the in-progress work to
# implement this case doesn't break the original code.
class TestGenerateContentSyncStreaming(TestCase):
    def configure_valid_response(
        self,
        response_text="The model_response",
        input_tokens=10,
        output_tokens=20,
    ):
        self.requests.add_response(
            create_valid_response(
                response_text=response_text,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )
        )

    def generate_content(self, *args, **kwargs):
    def generate_content_stream(self, *args, **kwargs):
        result = []
        for response in self.client.models.generate_content_stream(  # pylint: disable=missing-kwoa
            *args, **kwargs
@@ -70,10 +30,18 @@ class TestGenerateContentSyncStreaming(TestCase):
            result.append(response)
        return result

    def test_async_generate_content_not_broken_by_instrumentation(self):
        self.configure_valid_response(response_text="Yep, it works!")
        responses = self.generate_content(
            model="gemini-2.0-flash", contents="Does this work?"
        )


class TestGenerateContentStreamingWithSingleResult(
    StreamingMixin, NonStreamingTestCase
):
    def generate_content(self, *args, **kwargs):
        responses = self.generate_content_stream(*args, **kwargs)
        self.assertEqual(len(responses), 1)
        self.assertEqual(responses[0].text, "Yep, it works!")
        return responses[0]


class TestGenerateContentStreamingWithStreamedResults(
    StreamingMixin, StreamingTestCase
):
    def generate_content(self, *args, **kwargs):
        return self.generate_content_stream(*args, **kwargs)
@@ -0,0 +1,38 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def create_valid_response(
    response_text="The model response", input_tokens=10, output_tokens=20
):
    return {
        "modelVersion": "gemini-2.0-flash-test123",
        "usageMetadata": {
            "promptTokenCount": input_tokens,
            "candidatesTokenCount": output_tokens,
            "totalTokenCount": input_tokens + output_tokens,
        },
        "candidates": [
            {
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "text": response_text,
                        }
                    ],
                }
            }
        ],
    }