|
4 | 4 | from collections import Counter |
5 | 5 | import logging |
6 | 6 | import random |
| 7 | +import textwrap |
7 | 8 |
|
8 | 9 | import rdflib |
9 | 10 | from rdflib import URIRef |
|
13 | 14 | from gp_learner import mutate_increase_dist |
14 | 15 | from gp_learner import mutate_merge_var |
15 | 16 | from gp_learner import mutate_simplify_pattern |
16 | | -from gp_learner import mutate_deep_narrow_path |
17 | 17 | from graph_pattern import GraphPattern |
18 | 18 | from graph_pattern import SOURCE_VAR |
19 | 19 | from graph_pattern import TARGET_VAR |
@@ -109,33 +109,36 @@ def test_mutate_merge_var(): |
109 | 109 | assert False, "merge never reached one of the cases: %s" % cases |
110 | 110 |
|
111 | 111 |
|
112 | | -def test_mutate_deep_narrow_path(): |
113 | | - p = Variable('p') |
114 | | - gp = GraphPattern([ |
115 | | - (SOURCE_VAR, p, TARGET_VAR) |
116 | | - ]) |
117 | | - child = mutate_deep_narrow_path(gp) |
118 | | - assert gp == child or len(child) > len(gp) |
119 | | - print(gp) |
120 | | - print(child) |
121 | | - |
| 112 | +def test_deep_narrow_path_query(): |
| 113 | + node_var = Variable('node_var') |
| 114 | + edge_var = Variable('edge_var') |
| 115 | + gtps = [ |
| 116 | + (dbp['Barrel'], dbp['Wine']), |
| 117 | + (dbp['Barrister'], dbp['Law']), |
| 118 | + (dbp['Beak'], dbp['Bird']), |
| 119 | + (dbp['Blanket'], dbp['Bed']), |
| 120 | + ] |
122 | 121 |
|
123 | | -def test_to_find_edge_var_for_narrow_path_query(): |
124 | | - node_var = Variable('node_variable') |
125 | | - edge_var = Variable('edge_variable') |
126 | 122 | gp = GraphPattern([ |
127 | 123 | (node_var, edge_var, SOURCE_VAR), |
128 | 124 | (SOURCE_VAR, wikilink, TARGET_VAR) |
129 | 125 | ]) |
130 | | - filter_node_count = 10 |
131 | | - filter_edge_count = 1 |
132 | | - limit_res = 32 |
133 | | - vars_ = {SOURCE_VAR,TARGET_VAR} |
134 | | - res = GraphPattern.to_find_edge_var_for_narrow_path_query(gp, edge_var, node_var, |
135 | | - vars_, filter_node_count, |
136 | | - filter_edge_count, limit_res) |
137 | | - print(gp) |
138 | | - print(res) |
| 126 | + |
| 127 | + vars_ = (SOURCE_VAR, TARGET_VAR) |
| 128 | + res = gp.to_deep_narrow_path_query( |
| 129 | + edge_var, node_var, vars_, {vars_: gtps}, |
| 130 | + limit=32, |
| 131 | + max_node_count=10, |
| 132 | + min_edge_count=2, |
| 133 | + ).strip() |
| 134 | + doc = gp.to_deep_narrow_path_query.__doc__ |
| 135 | + doc_str_example_query = "\n".join([ |
| 136 | + l for l in doc.splitlines() |
| 137 | + if l.startswith(' ') |
| 138 | + ]) |
| 139 | + doc_str_example_query = textwrap.dedent(doc_str_example_query) |
| 140 | + assert res == doc_str_example_query, \ |
| 141 | + "res:\n%s\n\ndoes not look like:\n\n%s" % (res, doc_str_example_query) |
139 | 142 |
|
140 | 143 |
|
141 | 144 | def test_simplify_pattern(): |
@@ -303,5 +306,4 @@ def test_gtp_scores(): |
303 | 306 |
|
304 | 307 |
|
305 | 308 | if __name__ == '__main__': |
306 | | - # test_mutate_deep_narrow_path() |
307 | | - test_to_find_edge_var_for_narrow_path_query() |
| 309 | + test_deep_narrow_path_query() |
0 commit comments