Fix tests

This commit is contained in:
Srikanth Chekuri
2020-11-21 22:22:57 +05:30
parent 174cabad29
commit 3786d6d0fa
6 changed files with 35 additions and 35 deletions

View File

@ -36,7 +36,7 @@ class TestSqlalchemyInstrumentation(TestBase):
spans = self.memory_exporter.get_finished_spans() spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1) self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "sqlite.query") self.assertEqual(spans[0].name, "SELECT 1 + 1;")
def test_not_recording(self): def test_not_recording(self):
mock_tracer = mock.Mock() mock_tracer = mock.Mock()
@ -70,4 +70,4 @@ class TestSqlalchemyInstrumentation(TestBase):
spans = self.memory_exporter.get_finished_spans() spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1) self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "sqlite.query") self.assertEqual(spans[0].name, "SELECT 1 + 1;")

View File

@ -20,7 +20,7 @@ from sqlalchemy.orm import sessionmaker
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.instrumentation.sqlalchemy.engine import _DB, _ROWS, _STMT from opentelemetry.instrumentation.sqlalchemy.engine import _DB, _STMT
from opentelemetry.test.test_base import TestBase from opentelemetry.test.test_base import TestBase
Base = declarative_base() Base = declarative_base()
@ -109,9 +109,8 @@ class SQLAlchemyTestMixin(TestBase):
SQLAlchemyInstrumentor().uninstrument() SQLAlchemyInstrumentor().uninstrument()
super().tearDown() super().tearDown()
def _check_span(self, span): def _check_span(self, span, name):
self.assertEqual(span.name, "{}.query".format(self.VENDOR)) self.assertEqual(span.name, name)
self.assertEqual(span.attributes.get("service"), self.SERVICE)
self.assertEqual(span.attributes.get(_DB), self.SQL_DB) self.assertEqual(span.attributes.get(_DB), self.SQL_DB)
self.assertIs(span.status.status_code, trace.status.StatusCode.UNSET) self.assertIs(span.status.status_code, trace.status.StatusCode.UNSET)
self.assertGreater((span.end_time - span.start_time), 0) self.assertGreater((span.end_time - span.start_time), 0)
@ -125,9 +124,13 @@ class SQLAlchemyTestMixin(TestBase):
spans = self.memory_exporter.get_finished_spans() spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1) self.assertEqual(len(spans), 1)
span = spans[0] span = spans[0]
self._check_span(span) stmt = "INSERT INTO players (id, name) VALUES "
if span.attributes.get("db.system") == "sqlite":
stmt += "(?, ?)"
else:
stmt += "(%(id)s, %(name)s)"
self._check_span(span, stmt)
self.assertIn("INSERT INTO players", span.attributes.get(_STMT)) self.assertIn("INSERT INTO players", span.attributes.get(_STMT))
self.assertEqual(span.attributes.get(_ROWS), 1)
self.check_meta(span) self.check_meta(span)
def test_session_query(self): def test_session_query(self):
@ -138,7 +141,12 @@ class SQLAlchemyTestMixin(TestBase):
spans = self.memory_exporter.get_finished_spans() spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1) self.assertEqual(len(spans), 1)
span = spans[0] span = spans[0]
self._check_span(span) stmt = "SELECT players.id AS players_id, players.name AS players_name \nFROM players \nWHERE players.name = "
if span.attributes.get("db.system") == "sqlite":
stmt += "?"
else:
stmt += "%(name_1)s"
self._check_span(span, stmt)
self.assertIn( self.assertIn(
"SELECT players.id AS players_id, players.name AS players_name \nFROM players \nWHERE players.name", "SELECT players.id AS players_id, players.name AS players_name \nFROM players \nWHERE players.name",
span.attributes.get(_STMT), span.attributes.get(_STMT),
@ -147,24 +155,26 @@ class SQLAlchemyTestMixin(TestBase):
def test_engine_connect_execute(self): def test_engine_connect_execute(self):
# ensures that engine.connect() is properly traced # ensures that engine.connect() is properly traced
stmt = "SELECT * FROM players"
with self.connection() as conn: with self.connection() as conn:
rows = conn.execute("SELECT * FROM players").fetchall() rows = conn.execute(stmt).fetchall()
self.assertEqual(len(rows), 0) self.assertEqual(len(rows), 0)
spans = self.memory_exporter.get_finished_spans() spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1) self.assertEqual(len(spans), 1)
span = spans[0] span = spans[0]
self._check_span(span) self._check_span(span, stmt)
self.assertEqual(span.attributes.get(_STMT), "SELECT * FROM players") self.assertEqual(span.attributes.get(_STMT), "SELECT * FROM players")
self.check_meta(span) self.check_meta(span)
def test_parent(self): def test_parent(self):
"""Ensure that sqlalchemy works with opentelemetry.""" """Ensure that sqlalchemy works with opentelemetry."""
stmt = "SELECT * FROM players"
tracer = self.tracer_provider.get_tracer("sqlalch_svc") tracer = self.tracer_provider.get_tracer("sqlalch_svc")
with tracer.start_as_current_span("sqlalch_op"): with tracer.start_as_current_span("sqlalch_op"):
with self.connection() as conn: with self.connection() as conn:
rows = conn.execute("SELECT * FROM players").fetchall() rows = conn.execute(stmt).fetchall()
self.assertEqual(len(rows), 0) self.assertEqual(len(rows), 0)
spans = self.memory_exporter.get_finished_spans() spans = self.memory_exporter.get_finished_spans()
@ -178,5 +188,4 @@ class SQLAlchemyTestMixin(TestBase):
self.assertEqual(parent_span.name, "sqlalch_op") self.assertEqual(parent_span.name, "sqlalch_op")
self.assertEqual(parent_span.instrumentation_info.name, "sqlalch_svc") self.assertEqual(parent_span.instrumentation_info.name, "sqlalch_svc")
self.assertEqual(child_span.name, "{}.query".format(self.VENDOR)) self.assertEqual(child_span.name, stmt)
self.assertEqual(child_span.attributes.get("service"), self.SERVICE)

View File

@ -64,7 +64,6 @@ class SQLAlchemyInstrumentTestCase(TestBase):
self.assertEqual(len(traces), 1) self.assertEqual(len(traces), 1)
span = traces[0] span = traces[0]
# check subset of span fields # check subset of span fields
self.assertEqual(span.name, "postgres.query") self.assertEqual(span.name, "SELECT 1")
self.assertEqual(span.attributes.get("service"), "postgres")
self.assertIs(span.status.status_code, trace.status.StatusCode.UNSET) self.assertIs(span.status.status_code, trace.status.StatusCode.UNSET)
self.assertGreater((span.end_time - span.start_time), 0) self.assertGreater((span.end_time - span.start_time), 0)

View File

@ -23,7 +23,7 @@ from opentelemetry.instrumentation.sqlalchemy.engine import (
_DB, _DB,
_HOST, _HOST,
_PORT, _PORT,
_ROWS, _USER,
_STMT, _STMT,
) )
@ -45,7 +45,6 @@ class MysqlConnectorTestCase(SQLAlchemyTestMixin):
VENDOR = "mysql" VENDOR = "mysql"
SQL_DB = "opentelemetry-tests" SQL_DB = "opentelemetry-tests"
SERVICE = "mysql"
ENGINE_ARGS = { ENGINE_ARGS = {
"url": "mysql+mysqlconnector://%(user)s:%(password)s@%(host)s:%(port)s/%(database)s" "url": "mysql+mysqlconnector://%(user)s:%(password)s@%(host)s:%(port)s/%(database)s"
% MYSQL_CONFIG % MYSQL_CONFIG
@ -55,6 +54,8 @@ class MysqlConnectorTestCase(SQLAlchemyTestMixin):
# check database connection tags # check database connection tags
self.assertEqual(span.attributes.get(_HOST), MYSQL_CONFIG["host"]) self.assertEqual(span.attributes.get(_HOST), MYSQL_CONFIG["host"])
self.assertEqual(span.attributes.get(_PORT), MYSQL_CONFIG["port"]) self.assertEqual(span.attributes.get(_PORT), MYSQL_CONFIG["port"])
self.assertEqual(span.attributes.get(_DB), MYSQL_CONFIG["database"])
self.assertEqual(span.attributes.get(_USER), MYSQL_CONFIG["user"])
def test_engine_execute_errors(self): def test_engine_execute_errors(self):
# ensures that SQL errors are reported # ensures that SQL errors are reported
@ -66,13 +67,11 @@ class MysqlConnectorTestCase(SQLAlchemyTestMixin):
self.assertEqual(len(spans), 1) self.assertEqual(len(spans), 1)
span = spans[0] span = spans[0]
# span fields # span fields
self.assertEqual(span.name, "{}.query".format(self.VENDOR)) self.assertEqual(span.name, "SELECT * FROM a_wrong_table")
self.assertEqual(span.attributes.get("service"), self.SERVICE)
self.assertEqual( self.assertEqual(
span.attributes.get(_STMT), "SELECT * FROM a_wrong_table" span.attributes.get(_STMT), "SELECT * FROM a_wrong_table"
) )
self.assertEqual(span.attributes.get(_DB), self.SQL_DB) self.assertEqual(span.attributes.get(_DB), self.SQL_DB)
self.assertIsNone(span.attributes.get(_ROWS))
self.check_meta(span) self.check_meta(span)
self.assertTrue(span.end_time - span.start_time > 0) self.assertTrue(span.end_time - span.start_time > 0)
# check the error # check the error

View File

@ -24,7 +24,6 @@ from opentelemetry.instrumentation.sqlalchemy.engine import (
_DB, _DB,
_HOST, _HOST,
_PORT, _PORT,
_ROWS,
_STMT, _STMT,
) )
@ -44,9 +43,8 @@ class PostgresTestCase(SQLAlchemyTestMixin):
__test__ = True __test__ = True
VENDOR = "postgres" VENDOR = "postgresql"
SQL_DB = "opentelemetry-tests" SQL_DB = "opentelemetry-tests"
SERVICE = "postgres"
ENGINE_ARGS = { ENGINE_ARGS = {
"url": "postgresql://%(user)s:%(password)s@%(host)s:%(port)s/%(dbname)s" "url": "postgresql://%(user)s:%(password)s@%(host)s:%(port)s/%(dbname)s"
% POSTGRES_CONFIG % POSTGRES_CONFIG
@ -67,13 +65,11 @@ class PostgresTestCase(SQLAlchemyTestMixin):
self.assertEqual(len(spans), 1) self.assertEqual(len(spans), 1)
span = spans[0] span = spans[0]
# span fields # span fields
self.assertEqual(span.name, "{}.query".format(self.VENDOR)) self.assertEqual(span.name, "SELECT * FROM a_wrong_table")
self.assertEqual(span.attributes.get("service"), self.SERVICE)
self.assertEqual( self.assertEqual(
span.attributes.get(_STMT), "SELECT * FROM a_wrong_table" span.attributes.get(_STMT), "SELECT * FROM a_wrong_table"
) )
self.assertEqual(span.attributes.get(_DB), self.SQL_DB) self.assertEqual(span.attributes.get(_DB), self.SQL_DB)
self.assertIsNone(span.attributes.get(_ROWS))
self.check_meta(span) self.check_meta(span)
self.assertTrue(span.end_time - span.start_time > 0) self.assertTrue(span.end_time - span.start_time > 0)
# check the error # check the error
@ -88,9 +84,8 @@ class PostgresCreatorTestCase(PostgresTestCase):
of `PostgresTestCase`, but it uses a specific `creator` function. of `PostgresTestCase`, but it uses a specific `creator` function.
""" """
VENDOR = "postgres" VENDOR = "postgresql"
SQL_DB = "opentelemetry-tests" SQL_DB = "opentelemetry-tests"
SERVICE = "postgres"
ENGINE_ARGS = { ENGINE_ARGS = {
"url": "postgresql://", "url": "postgresql://",
"creator": lambda: psycopg2.connect(**POSTGRES_CONFIG), "creator": lambda: psycopg2.connect(**POSTGRES_CONFIG),

View File

@ -18,7 +18,7 @@ import pytest
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.instrumentation.sqlalchemy.engine import _DB, _ROWS, _STMT from opentelemetry.instrumentation.sqlalchemy.engine import _DB, _STMT
from .mixins import SQLAlchemyTestMixin from .mixins import SQLAlchemyTestMixin
@ -30,26 +30,24 @@ class SQLiteTestCase(SQLAlchemyTestMixin):
VENDOR = "sqlite" VENDOR = "sqlite"
SQL_DB = ":memory:" SQL_DB = ":memory:"
SERVICE = "sqlite"
ENGINE_ARGS = {"url": "sqlite:///:memory:"} ENGINE_ARGS = {"url": "sqlite:///:memory:"}
def test_engine_execute_errors(self): def test_engine_execute_errors(self):
# ensures that SQL errors are reported # ensures that SQL errors are reported
stmt = "SELECT * FROM a_wrong_table"
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
with self.connection() as conn: with self.connection() as conn:
conn.execute("SELECT * FROM a_wrong_table").fetchall() conn.execute(stmt).fetchall()
spans = self.memory_exporter.get_finished_spans() spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1) self.assertEqual(len(spans), 1)
span = spans[0] span = spans[0]
# span fields # span fields
self.assertEqual(span.name, "{}.query".format(self.VENDOR)) self.assertEqual(span.name, stmt)
self.assertEqual(span.attributes.get("service"), self.SERVICE)
self.assertEqual( self.assertEqual(
span.attributes.get(_STMT), "SELECT * FROM a_wrong_table" span.attributes.get(_STMT), "SELECT * FROM a_wrong_table"
) )
self.assertEqual(span.attributes.get(_DB), self.SQL_DB) self.assertEqual(span.attributes.get(_DB), self.SQL_DB)
self.assertIsNone(span.attributes.get(_ROWS))
self.assertTrue((span.end_time - span.start_time) > 0) self.assertTrue((span.end_time - span.start_time) > 0)
# check the error # check the error
self.assertIs( self.assertIs(