mirror of
https://github.com/open-telemetry/opentelemetry-python-contrib.git
synced 2025-08-02 11:31:52 +08:00
Add SQLAlchemy multithreading test (#468)
This commit is contained in:
@ -13,10 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from sqlalchemy import Column, Integer, String, create_engine
|
||||
from sqlalchemy import Column, Integer, String, create_engine, insert
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import close_all_sessions, scoped_session, sessionmaker
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
||||
@ -199,3 +201,45 @@ class SQLAlchemyTestMixin(TestBase):
|
||||
self.assertEqual(parent_span.instrumentation_info.name, "sqlalch_svc")
|
||||
|
||||
self.assertEqual(child_span.name, "SELECT " + self.SQL_DB)
|
||||
|
||||
def test_multithreading(self):
|
||||
"""Ensure spans are captured correctly in a multithreading scenario
|
||||
|
||||
We also expect no logged warnings about calling end() on an ended span.
|
||||
"""
|
||||
|
||||
if self.VENDOR == "sqlite":
|
||||
return
|
||||
|
||||
def insert_player(session):
|
||||
_session = session()
|
||||
player = Player(name="Player")
|
||||
_session.add(player)
|
||||
_session.commit()
|
||||
_session.query(Player).all()
|
||||
|
||||
def insert_players(session):
|
||||
_session = session()
|
||||
players = []
|
||||
for player_number in range(3):
|
||||
players.append(Player(name=f"Player {player_number}"))
|
||||
_session.add_all(players)
|
||||
_session.commit()
|
||||
|
||||
session_factory = sessionmaker(bind=self.engine)
|
||||
# pylint: disable=invalid-name
|
||||
Session = scoped_session(session_factory)
|
||||
thread_one = threading.Thread(target=insert_player, args=(Session,))
|
||||
thread_two = threading.Thread(target=insert_players, args=(Session,))
|
||||
|
||||
logger = logging.getLogger("opentelemetry.sdk.trace")
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertLogs(logger, level="WARNING"):
|
||||
thread_one.start()
|
||||
thread_two.start()
|
||||
thread_one.join()
|
||||
thread_two.join()
|
||||
close_all_sessions()
|
||||
|
||||
spans = self.memory_exporter.get_finished_spans()
|
||||
self.assertEqual(len(spans), 5)
|
||||
|
Reference in New Issue
Block a user