| 
13 | 13 | import re  | 
14 | 14 | 
 
  | 
15 | 15 | from testcontainers.core.config import testcontainers_config as c  | 
16 |  | -from testcontainers.core.generic import DbContainer  | 
17 |  | -from testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs  | 
18 |  | -from trino.dbapi import connect  | 
 | 16 | +from testcontainers.core.generic import DockerContainer  | 
 | 17 | +from testcontainers.core.wait_strategies import LogMessageWaitStrategy  | 
19 | 18 | 
 
  | 
20 | 19 | 
 
  | 
21 |  | -class TrinoContainer(DbContainer):  | 
 | 20 | +class TrinoContainer(DockerContainer):  | 
22 | 21 |     def __init__(  | 
23 | 22 |         self,  | 
24 | 23 |         image="trinodb/trino:latest",  | 
25 | 24 |         user: str = "test",  | 
26 | 25 |         port: int = 8080,  | 
 | 26 | +        container_start_timeout: int = 30,  | 
27 | 27 |         **kwargs,  | 
28 | 28 |     ):  | 
29 | 29 |         super().__init__(image=image, **kwargs)  | 
30 | 30 |         self.user = user  | 
31 | 31 |         self.port = port  | 
32 | 32 |         self.with_exposed_ports(self.port)  | 
33 |  | - | 
34 |  | -    @wait_container_is_ready()  | 
35 |  | -    def _connect(self) -> None:  | 
36 |  | -        wait_for_logs(  | 
37 |  | -            self,  | 
38 |  | -            re.compile(".*======== SERVER STARTED ========.*", re.MULTILINE).search,  | 
39 |  | -            c.max_tries,  | 
40 |  | -            c.sleep_time,  | 
41 |  | -        )  | 
42 |  | -        conn = connect(  | 
43 |  | -            host=self.get_container_host_ip(),  | 
44 |  | -            port=self.get_exposed_port(self.port),  | 
45 |  | -            user=self.user,  | 
 | 33 | +        self.waiting_for(  | 
 | 34 | +            LogMessageWaitStrategy(re.compile(".*======== SERVER STARTED ========.*", re.MULTILINE))  | 
 | 35 | +            .with_poll_interval(c.sleep_time)  | 
 | 36 | +            .with_startup_timeout(container_start_timeout)  | 
46 | 37 |         )  | 
47 |  | -        cur = conn.cursor()  | 
48 |  | -        cur.execute("SELECT 1")  | 
49 |  | -        cur.fetchall()  | 
50 |  | -        conn.close()  | 
51 | 38 | 
 
  | 
52 | 39 |     def get_connection_url(self):  | 
53 | 40 |         return f"trino://{self.user}@{self.get_container_host_ip()}:{self.port}"  | 
 | 
0 commit comments