11import json
2+ import os
23import socket
4+ import typing as t
35from codecs import open
4- from collections import namedtuple
5- from contextlib import closing , contextmanager
6+ from contextlib import contextmanager
67from os .path import abspath , dirname , isfile , join
8+ from pathlib import Path
79from random import choice
810from string import ascii_lowercase , ascii_uppercase , digits
911from time import sleep
1012
1113import docker
1214import mysql .connector
1315import pytest
16+ from _pytest ._py .path import LocalPath
17+ from _pytest .config import Config
18+ from _pytest .config .argparsing import Parser
19+ from _pytest .legacypath import TempdirFactory
1420from click .testing import CliRunner
21+ from docker import DockerClient
1522from docker .errors import NotFound
16- from mysql .connector import errorcode
23+ from docker .models .containers import Container
24+ from faker import Faker
25+ from mysql .connector import CMySQLConnection , MySQLConnection , errorcode
26+ from mysql .connector .pooling import PooledMySQLConnection
1727from requests import HTTPError
1828from sqlalchemy .exc import IntegrityError
29+ from sqlalchemy .orm import Session
1930from sqlalchemy_utils import database_exists , drop_database
2031
21- from .database import Database
22- from .factories import ArticleFactory , AuthorFactory , CrazyNameFactory , ImageFactory , MiscFactory , TagFactory
32+ from . import database , factories , models
2333
2434
25- def pytest_addoption (parser ):
35+ def pytest_addoption (parser : "Parser" ):
2636 parser .addoption (
2737 "--mysql-user" ,
2838 dest = "mysql_user" ,
@@ -78,9 +88,9 @@ def pytest_addoption(parser):
7888
7989
8090@pytest .fixture (scope = "session" , autouse = True )
81- def cleanup_hanged_docker_containers ():
91+ def cleanup_hanged_docker_containers () -> None :
8292 try :
83- client = docker .from_env ()
93+ client : DockerClient = docker .from_env ()
8494 for container in client .containers .list ():
8595 if container .name == "pytest_mysql_to_sqlite3" :
8696 container .kill ()
@@ -89,9 +99,9 @@ def cleanup_hanged_docker_containers():
8999 pass
90100
91101
92- def pytest_keyboard_interrupt ():
102+ def pytest_keyboard_interrupt () -> None :
93103 try :
94- client = docker .from_env ()
104+ client : DockerClient = docker .from_env ()
95105 for container in client .containers .list ():
96106 if container .name == "pytest_mysql_to_sqlite3" :
97107 container .kill ()
@@ -103,17 +113,17 @@ def pytest_keyboard_interrupt():
103113class Helpers :
104114 @staticmethod
105115 @contextmanager
106- def not_raises (exception ) :
116+ def not_raises (exception : t . Type [ Exception ]) -> t . Generator :
107117 try :
108118 yield
109119 except exception :
110120 raise pytest .fail ("DID RAISE {0}" .format (exception ))
111121
112122 @staticmethod
113123 @contextmanager
114- def session_scope (db ) :
124+ def session_scope (db : database . Database ) -> t . Generator :
115125 """Provide a transactional scope around a series of operations."""
116- session = db .Session ()
126+ session : Session = db .Session ()
117127 try :
118128 yield session
119129 session .commit ()
@@ -125,29 +135,37 @@ def session_scope(db):
125135
126136
127137@pytest .fixture
128- def helpers ():
138+ def helpers () -> t . Type [ Helpers ] :
129139 return Helpers
130140
131141
132142@pytest .fixture ()
133- def sqlite_database (tmpdir ) :
134- db_name = "" .join (choice (ascii_uppercase + ascii_lowercase + digits ) for _ in range (32 ))
135- return str (tmpdir .join ("{}.sqlite3" .format (db_name )))
143+ def sqlite_database (tmpdir : LocalPath ) -> t . Union [ str , Path , "os.PathLike[t.Any]" ] :
144+ db_name : str = "" .join (choice (ascii_uppercase + ascii_lowercase + digits ) for _ in range (32 ))
145+ return Path (tmpdir .join (Path ( "{}.sqlite3" .format (db_name ) )))
136146
137147
138- def is_port_in_use (port , host = "0.0.0.0" ):
148+ def is_port_in_use (port : int , host : str = "0.0.0.0" ) -> bool :
139149 with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
140150 return s .connect_ex ((host , port )) == 0
141151
142152
143- @pytest .fixture (scope = "session" )
144- def mysql_credentials (pytestconfig ):
145- MySQLCredentials = namedtuple ("MySQLCredentials" , ["user" , "password" , "host" , "port" , "database" ])
153+ class MySQLCredentials (t .NamedTuple ):
154+ """MySQL credentials."""
146155
147- db_credentials_file = abspath (join (dirname (__file__ ), "db_credentials.json" ))
156+ user : str
157+ password : str
158+ host : str
159+ port : int
160+ database : str
161+
162+
163+ @pytest .fixture (scope = "session" )
164+ def mysql_credentials (pytestconfig : Config ) -> MySQLCredentials :
165+ db_credentials_file : str = abspath (join (dirname (__file__ ), "db_credentials.json" ))
148166 if isfile (db_credentials_file ):
149167 with open (db_credentials_file , "r" , "utf-8" ) as fh :
150- db_credentials = json .load (fh )
168+ db_credentials : t . Dict [ str , t . Any ] = json .load (fh )
151169 return MySQLCredentials (
152170 user = db_credentials ["mysql_user" ],
153171 password = db_credentials ["mysql_password" ],
@@ -156,16 +174,13 @@ def mysql_credentials(pytestconfig):
156174 port = db_credentials ["mysql_port" ],
157175 )
158176
159- port = pytestconfig .getoption ("mysql_port" ) or 3306
177+ port : int = pytestconfig .getoption ("mysql_port" ) or 3306
160178 if pytestconfig .getoption ("use_docker" ):
161179 while is_port_in_use (port , pytestconfig .getoption ("mysql_host" )):
162180 if port >= 2 ** 16 - 1 :
163181 pytest .fail (
164182 "No ports appear to be available on the host {}" .format (pytestconfig .getoption ("mysql_host" ))
165183 )
166- raise ConnectionError (
167- "No ports appear to be available on the host {}" .format (pytestconfig .getoption ("mysql_host" ))
168- )
169184 port += 1
170185
171186 return MySQLCredentials (
@@ -178,11 +193,11 @@ def mysql_credentials(pytestconfig):
178193
179194
180195@pytest .fixture (scope = "session" )
181- def mysql_instance (mysql_credentials , pytestconfig ) :
182- container = None
183- mysql_connection = None
184- mysql_available = False
185- mysql_connection_retries = 15 # failsafe
196+ def mysql_instance (mysql_credentials : MySQLCredentials , pytestconfig : Config ) -> t . Iterator [ MySQLConnection ] :
197+ container : t . Optional [ Container ] = None
198+ mysql_connection : t . Optional [ t . Union [ PooledMySQLConnection , MySQLConnection , CMySQLConnection ]] = None
199+ mysql_available : bool = False
200+ mysql_connection_retries : int = 15 # failsafe
186201
187202 db_credentials_file = abspath (join (dirname (__file__ ), "db_credentials.json" ))
188203 if isfile (db_credentials_file ):
@@ -198,7 +213,6 @@ def mysql_instance(mysql_credentials, pytestconfig):
198213 client = docker .from_env ()
199214 except Exception as err :
200215 pytest .fail (str (err ))
201- raise
202216
203217 docker_mysql_image = pytestconfig .getoption ("docker_mysql_image" ) or "mysql:latest"
204218
@@ -208,7 +222,6 @@ def mysql_instance(mysql_credentials, pytestconfig):
208222 client .images .pull (docker_mysql_image )
209223 except (HTTPError , NotFound ) as err :
210224 pytest .fail (str (err ))
211- raise
212225
213226 container = client .containers .run (
214227 image = docker_mysql_image ,
@@ -256,17 +269,22 @@ def mysql_instance(mysql_credentials, pytestconfig):
256269 if not mysql_available and mysql_connection_retries <= 0 :
257270 raise ConnectionAbortedError ("Maximum MySQL connection retries exhausted! Are you sure MySQL is running?" )
258271
259- yield
272+ yield # type: ignore[misc]
260273
261- if use_docker :
274+ if use_docker and container is not None :
262275 container .kill ()
263276
264277
265278@pytest .fixture (scope = "session" )
266- def mysql_database (tmpdir_factory , mysql_instance , mysql_credentials , _session_faker ):
267- temp_image_dir = tmpdir_factory .mktemp ("images" )
268-
269- db = Database (
279+ def mysql_database (
280+ tmpdir_factory : TempdirFactory ,
281+ mysql_instance : MySQLConnection ,
282+ mysql_credentials : MySQLCredentials ,
283+ _session_faker : Faker ,
284+ ) -> t .Iterator [database .Database ]:
285+ temp_image_dir : LocalPath = tmpdir_factory .mktemp ("images" )
286+
287+ db : database .Database = database .Database (
270288 "mysql+mysqldb://{user}:{password}@{host}:{port}/{database}" .format (
271289 user = mysql_credentials .user ,
272290 password = mysql_credentials .password ,
@@ -278,13 +296,13 @@ def mysql_database(tmpdir_factory, mysql_instance, mysql_credentials, _session_f
278296
279297 with Helpers .session_scope (db ) as session :
280298 for _ in range (_session_faker .pyint (min_value = 12 , max_value = 24 )):
281- article = ArticleFactory ()
282- article .authors .append (AuthorFactory ())
283- article .tags .append (TagFactory ())
284- article .misc .append (MiscFactory ())
299+ article : models . Article = factories . ArticleFactory ()
300+ article .authors .append (factories . AuthorFactory ())
301+ article .tags .append (factories . TagFactory ())
302+ article .misc .append (factories . MiscFactory ())
285303 for _ in range (_session_faker .pyint (min_value = 1 , max_value = 4 )):
286304 article .images .append (
287- ImageFactory (
305+ factories . ImageFactory (
288306 path = join (
289307 str (temp_image_dir ),
290308 _session_faker .year (),
@@ -297,7 +315,7 @@ def mysql_database(tmpdir_factory, mysql_instance, mysql_credentials, _session_f
297315 session .add (article )
298316
299317 for _ in range (_session_faker .pyint (min_value = 12 , max_value = 24 )):
300- session .add (CrazyNameFactory ())
318+ session .add (factories . CrazyNameFactory ())
301319 try :
302320 session .commit ()
303321 except IntegrityError :
@@ -310,5 +328,5 @@ def mysql_database(tmpdir_factory, mysql_instance, mysql_credentials, _session_f
310328
311329
312330@pytest .fixture ()
313- def cli_runner ():
331+ def cli_runner () -> t . Iterator [ CliRunner ] :
314332 yield CliRunner ()
0 commit comments