Skip to content

Commit cf92581

Browse files
committed
Support loading *.aig files in binary format
1 parent 1f1d952 commit cf92581

File tree

2 files changed

+170
-46
lines changed

2 files changed

+170
-46
lines changed

aiger/parser.py

Lines changed: 131 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,31 @@
1-
import io
21
import re
32
from collections import defaultdict
43
from functools import reduce
54
from typing import Mapping, List, Optional
5+
from uuid import uuid1
66

77
import attr
88
import funcy as fn
99
from bidict import bidict
10+
from sortedcontainers import SortedDict
1011
from toposort import toposort_flatten
11-
from uuid import uuid1
12-
from sortedcontainers import SortedList, SortedSet, SortedDict
1312

1413
import aiger as A
1514

1615

1716
@attr.s(auto_attribs=True, repr=False)
1817
class Header:
18+
binary_mode: bool
1919
max_var_index: int
2020
num_inputs: int
2121
num_latches: int
2222
num_outputs: int
2323
num_ands: int
2424

2525
def __repr__(self):
26-
return f"aag {self.max_var_index} {self.num_inputs} " \
27-
f"{self.num_latches} {self.num_outputs} {self.num_ands}"
26+
mode = 'aig' if self.binary_mode else 'aag'
27+
return f"{mode} {self.max_var_index} {self.num_inputs} " \
28+
f"{self.num_latches} {self.num_outputs} {self.num_ands}"
2829

2930

3031
NOT_DONE_PARSING_ERROR = "Parsing rules exhausted at line {}!\n{}"
@@ -79,28 +80,41 @@ def remaining_inputs(self):
7980
return self.header.num_inputs - len(self.inputs)
8081

8182

82-
HEADER_PATTERN = re.compile(r"aag (\d+) (\d+) (\d+) (\d+) (\d+)\n")
83+
def _consume_stream(stream, delim) -> str:
84+
line = bytearray()
85+
ch = -1
86+
delim = ord(delim)
87+
while ch != delim:
88+
ch = next(stream, delim)
89+
line.append(ch)
90+
return line.decode('ascii')
91+
8392

93+
HEADER_PATTERN = re.compile(r"(a[ai]g) (\d+) (\d+) (\d+) (\d+) (\d+)\n")
8494

85-
def parse_header(state, line) -> bool:
95+
96+
def parse_header(state, stream) -> bool:
8697
if state.header is not None:
8798
return False
8899

100+
line = _consume_stream(stream, '\n')
89101
match = HEADER_PATTERN.match(line)
90102
if not match:
91-
raise ValueError(f"Failed to parse aag HEADER. {line}")
103+
raise ValueError(f"Failed to parse aag/aig HEADER. {line}")
92104

93105
try:
94-
ids = fn.lmap(int, match.groups())
106+
binary_mode = match.group(1) == 'aig'
107+
ids = fn.lmap(int, match.groups()[1:])
95108

96109
if any(x < 0 for x in ids):
97-
raise ValueError("Indicies must be positive!")
110+
raise ValueError("Indices must be positive!")
98111

99112
max_idx, nin, nlatch, nout, nand = ids
100113
if nin + nlatch + nand > max_idx:
101114
raise ValueError("Sum of claimed indices greater than max.")
102115

103116
state.header = Header(
117+
binary_mode=binary_mode,
104118
max_var_index=max_idx,
105119
num_inputs=nin,
106120
num_latches=nlatch,
@@ -116,21 +130,38 @@ def parse_header(state, line) -> bool:
116130
IO_PATTERN = re.compile(r"(\d+)\s*\n")
117131

118132

119-
def parse_input(state, line) -> bool:
120-
match = IO_PATTERN.match(line)
121-
122-
if match is None or state.remaining_inputs <= 0:
123-
return False
124-
lit = int(line)
133+
def _add_input(state, lit):
125134
state.inputs.append(lit)
126135
state.nodes[lit] = set()
127-
return True
128136

129137

130-
def parse_output(state, line) -> bool:
138+
def parse_input(state, stream) -> bool:
139+
if state.remaining_inputs <= 0:
140+
return False
141+
142+
if state.header.binary_mode:
143+
for lit in range(2, 2 * (state.header.num_inputs + 1), 2):
144+
_add_input(state, lit)
145+
return False
146+
147+
line = _consume_stream(stream, '\n')
131148
match = IO_PATTERN.match(line)
132-
if match is None or state.remaining_outputs <= 0:
149+
150+
if match is None:
151+
raise ValueError(f"Expecting an input: {line}")
152+
153+
_add_input(state, int(line))
154+
return True
155+
156+
157+
def parse_output(state, stream) -> bool:
158+
if state.remaining_outputs <= 0:
133159
return False
160+
161+
line = _consume_stream(stream, '\n')
162+
match = IO_PATTERN.match(line)
163+
if match is None:
164+
raise ValueError(f"Expecting an output: {line}")
134165
lit = int(line)
135166
state.outputs.append(lit)
136167
if lit & 1:
@@ -139,17 +170,28 @@ def parse_output(state, line) -> bool:
139170

140171

141172
LATCH_PATTERN = re.compile(r"(\d+) (\d+)(?: (\d+))?\n")
173+
LATCH_PATTERN_BINARY = re.compile(r"(\d+)(?: (\d+))?\n")
142174

143175

144-
def parse_latch(state, line) -> bool:
176+
def parse_latch(state, stream) -> bool:
145177
if state.remaining_latches <= 0:
146178
return False
147179

148-
match = LATCH_PATTERN.match(line)
149-
if match is None:
150-
raise ValueError("Expecting a latch: {line}")
180+
line = _consume_stream(stream, '\n')
181+
182+
if state.header.binary_mode:
183+
match = LATCH_PATTERN_BINARY.match(line)
184+
if match is None:
185+
raise ValueError(f"Expecting a latch: {line}")
186+
idx = state.header.num_inputs + len(state.latches) + 1
187+
lit = 2 * idx
188+
elems = (lit,) + match.groups()
189+
else:
190+
match = LATCH_PATTERN.match(line)
191+
if match is None:
192+
raise ValueError(f"Expecting a latch: {line}")
193+
elems = match.groups()
151194

152-
elems = match.groups()
153195
if elems[2] is None:
154196
elems = elems[:2] + (0,)
155197
elems = fn.lmap(int, elems)
@@ -165,30 +207,71 @@ def parse_latch(state, line) -> bool:
165207
AND_PATTERN = re.compile(r"(\d+) (\d+) (\d+)\s*\n")
166208

167209

168-
def parse_and(state, line) -> bool:
169-
if state.header.num_ands <= 0:
170-
return False
210+
def _read_delta(data):
211+
ch = next(data)
212+
i = 0
213+
delta = 0
214+
while (ch & 0x80) != 0:
215+
if i == 5:
216+
raise ValueError("Invalid byte in delta encoding")
217+
delta |= (ch & 0x7f) << (7 * i)
218+
i += 1
219+
ch = next(data)
220+
if i == 5 and ch >= 8:
221+
raise ValueError("Invalid byte in delta encoding")
222+
223+
delta |= ch << (7 * i)
224+
return delta
171225

172-
match = AND_PATTERN.match(line)
173-
if match is None:
174-
return False
175226

176-
elems = fn.lmap(int, match.groups())
227+
def _add_and(state, elems):
228+
elems = fn.lmap(int, elems)
177229
state.header.num_ands -= 1
178230
deps = set(elems[1:])
179231
state.nodes[elems[0]] = deps
180232
for dep in deps:
181233
if dep & 1:
182234
state.nodes[dep] = {dep ^ 1}
235+
236+
237+
def parse_and(state, stream) -> bool:
238+
if state.header.num_ands <= 0:
239+
return False
240+
241+
if state.header.binary_mode:
242+
lhs = 2 * (state.header.num_inputs + state.header.num_latches)
243+
for i in range(state.header.num_ands):
244+
lhs += 2
245+
delta = _read_delta(stream)
246+
if delta > lhs:
247+
raise ValueError(f"Invalid lhs {lhs} or delta {delta}")
248+
rhs0 = lhs - delta
249+
delta = _read_delta(stream)
250+
if delta > rhs0:
251+
raise ValueError(f"Invalid rhs0 {rhs0} or delta {delta}")
252+
rhs1 = rhs0 - delta
253+
_add_and(state, (lhs, rhs0, rhs1))
254+
255+
else:
256+
line = _consume_stream(stream, '\n')
257+
match = AND_PATTERN.match(line)
258+
if match is None:
259+
raise ValueError(f"Expecting an and: {line}")
260+
261+
_add_and(state, match.groups())
183262
return True
184263

185264

186265
SYM_PATTERN = re.compile(r"([ilo])(\d+) (.*)\s*\n")
187266

188267

189-
def parse_symbol(state, line) -> bool:
268+
def parse_symbol(state, stream) -> bool:
269+
line = _consume_stream(stream, '\n')
190270
match = SYM_PATTERN.match(line)
191271
if match is None:
272+
# We might have consumed the 'c' starting the comments section
273+
if line.rstrip() == 'c':
274+
state.comments = []
192275
return False
193276

194277
kind, idx, name = match.groups()
@@ -202,7 +285,8 @@ def parse_symbol(state, line) -> bool:
202285
return True
203286

204287

205-
def parse_comment(state, line) -> bool:
288+
def parse_comment(state, stream) -> bool:
289+
line = _consume_stream(stream, '\n')
206290
if state.comments is not None:
207291
state.comments.append(line.rstrip())
208292
elif line.rstrip() == 'c':
@@ -227,20 +311,25 @@ def finish_table(table, keys):
227311
return {table[i]: key for i, key in enumerate(keys)}
228312

229313

230-
def parse(lines, to_aig: bool = True):
231-
if isinstance(lines, str):
232-
lines = io.StringIO(lines)
314+
def parse(stream):
315+
if isinstance(stream, list):
316+
stream = ''.join(stream)
317+
if isinstance(stream, str):
318+
stream = bytes(stream, 'ascii')
319+
stream = iter(stream)
233320

234321
state = State()
235322
parsers = parse_seq()
236323
parser = next(parsers)
237324

238-
for i, line in enumerate(lines):
239-
while not parser(state, line):
325+
i = 0
326+
while stream.__length_hint__() > 0:
327+
i += 1
328+
while not parser(state, stream):
240329
parser = next(parsers, None)
241330

242331
if parser is None:
243-
raise ValueError(NOT_DONE_PARSING_ERROR.format(i + 1, state))
332+
raise ValueError(NOT_DONE_PARSING_ERROR.format(i, state))
244333

245334
if parser not in (parse_header, parse_output, parse_comment, parse_symbol):
246335
raise ValueError(DONE_PARSING_ERROR.format(state))
@@ -284,9 +373,9 @@ def parse(lines, to_aig: bool = True):
284373
)
285374

286375

287-
def load(path: str, to_aig: bool = True):
288-
with open(path, 'r') as f:
289-
return parse(''.join(f.readlines()), to_aig=to_aig)
376+
def load(path: str):
377+
with open(path, 'rb') as f:
378+
return parse(f.read())
290379

291380

292381
__all__ = ['load', 'parse']

tests/test_parser.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,45 @@ def test_io_order():
142142
o1 oa
143143
""")
144144

145-
146145
data = {'a': False, 'b': True}
147146
assert circ1(data) \
148-
== circ2(data) \
149-
== circ3(data) \
150-
== circ4(data)
147+
== circ2(data) \
148+
== circ3(data) \
149+
== circ4(data)
150+
151151

152+
# ----------- BINARY FILE PARSER TESTS ----------------
153+
154+
@given(st.data())
155+
def test_smoke1_aig(data):
156+
circ1 = aigp.parse(TEST1)
157+
circ2 = aigp.load("tests/aig/test1.aig")
158+
test_input = {f'{i}': data.draw(st.booleans()) for i in circ1.inputs}
159+
assert circ1(test_input) == circ2(test_input)
160+
161+
162+
@given(st.data())
163+
def test_smoke2_aig(data):
164+
circ1 = aigp.parse(TEST2)
165+
circ2 = aigp.load("tests/aig/test2.aig")
166+
test_input = {f'{i}': data.draw(st.booleans()) for i in circ1.inputs}
167+
assert circ1(test_input) == circ2(test_input)
168+
169+
170+
def test_mutex_example_smoke_aig():
171+
aigp.load('tests/aig/mutex_converted.aig')
172+
173+
174+
def test_degenerate_smoke_aig():
175+
import aiger as A
176+
177+
expr = A.BoolExpr(A.load("tests/aig/test_degenerate1.aig"))
178+
assert expr({}) is False
179+
expr = A.BoolExpr(A.load("tests/aig/test_degenerate2.aig"))
180+
assert expr({}) is True
181+
circ = A.load("tests/aig/test_degenerate3.aig")
182+
assert len(circ.node_map) == 0
183+
assert circ.inputs == circ.outputs == circ.latches == set()
184+
185+
circ = A.load("tests/aig/test_degenerate4.aig")
186+
assert not any(circ({})[0].values())

0 commit comments

Comments
 (0)