diff --git a/clickhouse_sqlalchemy/drivers/base.py b/clickhouse_sqlalchemy/drivers/base.py index b93c6aa0..b99c29e3 100644 --- a/clickhouse_sqlalchemy/drivers/base.py +++ b/clickhouse_sqlalchemy/drivers/base.py @@ -53,6 +53,10 @@ '_lowcardinality': types.LowCardinality, '_tuple': types.Tuple, '_map': types.Map, + 'Point': types.Point, + 'Ring': types.Ring, + 'Polygon': types.Polygon, + 'MultiPolygon': types.MultiPolygon } diff --git a/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py b/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py index 26647841..3a91c784 100644 --- a/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py +++ b/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py @@ -131,3 +131,15 @@ def visit_map(self, type_, **kw): self.process(key_type, **kw), self.process(value_type, **kw) ) + + def visit_point(self, type_, **kw): + return 'Point' + + def visit_ring(self, type_, **kw): + return 'Ring' + + def visit_polygon(self, type_, *kw): + return 'Polygon' + + def visit_multipolygon(self, type_, *kw): + return 'MultiPolygon' diff --git a/clickhouse_sqlalchemy/drivers/http/escaper.py b/clickhouse_sqlalchemy/drivers/http/escaper.py index 10ebf991..cd7c2647 100644 --- a/clickhouse_sqlalchemy/drivers/http/escaper.py +++ b/clickhouse_sqlalchemy/drivers/http/escaper.py @@ -25,7 +25,10 @@ def escape_string(self, value): def escape(self, parameters): if isinstance(parameters, dict): return {k: self.escape_item(v) for k, v in parameters.items()} - elif isinstance(parameters, (list, tuple)): + elif isinstance(parameters, tuple): + return "(" + ",".join( + [str(self.escape_item(x)) for x in parameters]) + ")" + elif isinstance(parameters, list): return "[" + ",".join( [str(self.escape_item(x)) for x in parameters]) + "]" else: @@ -62,7 +65,11 @@ def escape_item(self, item): return self.escape_decimal(item) elif isinstance(item, str): return self.escape_string(item) - elif isinstance(item, (list, tuple)): + elif isinstance(item, tuple): + return "(" + ", ".join( + [str(self.escape_item(x)) for x in item] + ) + ")" + elif isinstance(item, list): return "[" + ", ".join( [str(self.escape_item(x)) for x in item] ) + "]" diff --git a/clickhouse_sqlalchemy/drivers/http/transport.py b/clickhouse_sqlalchemy/drivers/http/transport.py index ad01ea56..cb24810e 100644 --- a/clickhouse_sqlalchemy/drivers/http/transport.py +++ b/clickhouse_sqlalchemy/drivers/http/transport.py @@ -47,6 +47,28 @@ def nothing_converter(x): return None +POINT_RE = re.compile(r'(-?\d*\.?\d+)') +RING_RE = re.compile(r'(\(.*?\))') +POLYGON_RE = re.compile(r'(\[.*?\])') +MULTIPOLYGON_RE = re.compile(r'\[\[.*?\]\]') + + +def point_converter(x): + return tuple([float(f) for f in POINT_RE.findall(x[1:-1])]) + + +def ring_converter(x): + return [point_converter(f) for f in RING_RE.findall(x[1:-1])] + + +def polygon_converter(x): + return [ring_converter(f) for f in POLYGON_RE.findall(x[1:-1])] + + +def multipolygon_converter(x): + return [polygon_converter(f) for f in MULTIPOLYGON_RE.findall(x[1:-1])] + + converters = { 'Int8': int, 'UInt8': int, @@ -70,6 +92,10 @@ def nothing_converter(x): 'IPv6': IPv6Address, 'Nullable': nullable_converter, 'Nothing': nothing_converter, + 'Point': point_converter, + 'Ring': ring_converter, + 'Polygon': polygon_converter, + 'MultiPolygon': multipolygon_converter } diff --git a/clickhouse_sqlalchemy/types/__init__.py b/clickhouse_sqlalchemy/types/__init__.py index 502e8a0f..0f9ae1d8 100644 --- a/clickhouse_sqlalchemy/types/__init__.py +++ b/clickhouse_sqlalchemy/types/__init__.py @@ -33,6 +33,10 @@ 'Nested', 'Tuple', 'Map', + 'Point', + 'Ring', + 'Polygon', + 'MultiPolygon' ] from .common import String @@ -69,3 +73,7 @@ from .ip import IPv4 from .ip import IPv6 from .nested import Nested +from .geo import Point +from .geo import Ring +from .geo import Polygon +from .geo import MultiPolygon diff --git a/clickhouse_sqlalchemy/types/geo.py b/clickhouse_sqlalchemy/types/geo.py new file mode 100644 index 00000000..06261b3d --- /dev/null +++ b/clickhouse_sqlalchemy/types/geo.py @@ -0,0 +1,17 @@ +from sqlalchemy import types + + +class Point(types.UserDefinedType): + __visit_name__ = "point" + + +class Ring(types.UserDefinedType): + __visit_name__ = "ring" + + +class Polygon(types.UserDefinedType): + __visit_name__ = "polygon" + + +class MultiPolygon(types.UserDefinedType): + __visit_name__ = "multipolygon" diff --git a/setup.py b/setup.py index 10352d6c..13bd924d 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ def read_version(): 'sqlalchemy>=1.4.24,<1.5', 'greenlet>=2.0.1', 'requests', - 'clickhouse-driver>=0.1.2', + 'clickhouse-driver>=0.2.4', 'asynch>=0.2.2', ], # Registering `clickhouse` as dialect. diff --git a/tests/types/test_geo.py b/tests/types/test_geo.py new file mode 100644 index 00000000..1c07c79a --- /dev/null +++ b/tests/types/test_geo.py @@ -0,0 +1,114 @@ +from sqlalchemy import Column +from sqlalchemy.sql.ddl import CreateTable + +from clickhouse_sqlalchemy import types, engines, Table +from tests.testcase import BaseTestCase +from tests.util import with_native_and_http_sessions + + +@with_native_and_http_sessions +class GeoPointTestCase(BaseTestCase): + table = Table( + 'test', BaseTestCase.metadata(), + Column('p', types.Point), + engines.Memory() + ) + + def test_create_table(self): + self.assertEqual( + self.compile(CreateTable(self.table)), + 'CREATE TABLE test (p Point) ENGINE = Memory' + ) + + def test_select_insert(self): + a = (10.1, 12.3) + + with self.create_table(self.table): + self.session.execute(self.table.insert(), [{'p': a}]) + qres = self.session.query(self.table.c.p) + res = qres.scalar() + self.assertEqual(res, a) + + def test_select_where_point(self): + a = (10.1, 12.3) + + with self.create_table(self.table): + self.session.execute(self.table.insert(), [{'p': a}]) + + self.assertEqual(self.session.query(self.table.c.p).filter( + self.table.c.p == (10.1, 12.3)).scalar(), a) + + +@with_native_and_http_sessions +class GeoRingTestCase(BaseTestCase): + table = Table( + 'test', BaseTestCase.metadata(), + Column('r', types.Ring), + engines.Memory() + ) + + def test_create_table(self): + self.assertEqual( + self.compile(CreateTable(self.table)), + 'CREATE TABLE test (r Ring) ENGINE = Memory' + ) + + def test_select_insert(self): + a = [(0, 0), (10, 0), (10, 10), (0, 10)] + + with self.create_table(self.table): + self.session.execute(self.table.insert(), [{'r': a}]) + qres = self.session.query(self.table.c.r) + res = qres.scalar() + self.assertEqual(res, a) + + +@with_native_and_http_sessions +class GeoPolygonTestCase(BaseTestCase): + table = Table( + 'test', BaseTestCase.metadata(), + Column('pg', types.Polygon), + engines.Memory() + ) + + def test_create_table(self): + self.assertEqual( + self.compile(CreateTable(self.table)), + 'CREATE TABLE test (pg Polygon) ENGINE = Memory' + ) + + def test_select_insert(self): + a = [[(20, 20), (50, 20), (50, 50), (20, 50)], + [(30, 30), (50, 50), (50, 30)]] + + with self.create_table(self.table): + self.session.execute(self.table.insert(), [{'pg': a}]) + qres = self.session.query(self.table.c.pg) + res = qres.scalar() + self.assertEqual(res, a) + + +@with_native_and_http_sessions +class GeoMultiPolygonTestCase(BaseTestCase): + table = Table( + 'test', BaseTestCase.metadata(), + Column('mpg', types.MultiPolygon), + engines.Memory() + ) + + def test_create_table(self): + self.assertEqual( + self.compile(CreateTable(self.table)), + 'CREATE TABLE test (mpg MultiPolygon) ENGINE = Memory' + ) + + def test_select_insert(self): + a = [[[(0, 0), (10, 0), (10, 10), (0, 10)]], + [[(20, 20), (50, 20), (50, 50), (20, 50)], + [(30, 30), (50, 50), (50, 30)]]] + + with self.create_table(self.table): + self.session.execute(self.table.insert(), [{'mpg': a}]) + qres = self.session.query(self.table.c.mpg) + res = qres.scalar() + self.assertEqual(res, a)