Skip to content

Commit 465c96a

Browse files
committed
feat(parser): Support for parsing binary AIG files.
2 parents 1f1d952 + d4ff113 commit 465c96a

File tree

10 files changed

+271
-56
lines changed

10 files changed

+271
-56
lines changed

aiger/aig.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def to_aig(circ, *, allow_lazy=False) -> AIG:
284284
if isinstance(circ, pathlib.Path) and circ.is_file():
285285
circ = parser.load(circ)
286286
elif isinstance(circ, str):
287-
if circ.startswith('aag '):
287+
if circ.startswith('aag ') or circ.startswith('aig '):
288288
circ = parser.parse(circ) # Assume it is an AIGER string.
289289
else:
290290
circ = parser.load(circ) # Assume it is a file path.

aiger/parser.py

Lines changed: 152 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,31 @@
1-
import io
21
import re
32
from collections import defaultdict
43
from functools import reduce
54
from typing import Mapping, List, Optional
5+
from uuid import uuid1
66

77
import attr
88
import funcy as fn
99
from bidict import bidict
10+
from sortedcontainers import SortedDict
1011
from toposort import toposort_flatten
11-
from uuid import uuid1
12-
from sortedcontainers import SortedList, SortedSet, SortedDict
1312

1413
import aiger as A
1514

1615

1716
@attr.s(auto_attribs=True, repr=False)
class Header:
    """Counts taken from the first line of an AIGER file.

    ``binary_mode`` records which magic word introduced the header; the
    remaining fields are the five integers that follow it, in file order.
    """

    binary_mode: bool
    max_var_index: int
    num_inputs: int
    num_latches: int
    num_outputs: int
    num_ands: int

    def __repr__(self):
        """Render the header as it appears in a file (without the newline)."""
        # 'aig' marks the binary flavour, 'aag' the ASCII one.
        magic = 'aig' if self.binary_mode else 'aag'
        counts = (
            self.max_var_index,
            self.num_inputs,
            self.num_latches,
            self.num_outputs,
            self.num_ands,
        )
        return f"{magic} " + " ".join(f"{count}" for count in counts)
2829

2930

3031
NOT_DONE_PARSING_ERROR = "Parsing rules exhausted at line {}!\n{}"
@@ -38,6 +39,13 @@ class Latch:
3839
init: bool = attr.ib(converter=bool)
3940

4041

42+
@attr.s(auto_attribs=True, frozen=True)
class And:
    """Immutable record of one AND gate.

    ``lhs`` is the literal the gate defines; ``rhs0`` and ``rhs1`` are the
    two literals it depends on.
    """

    lhs: int
    rhs0: int
    rhs1: int
47+
48+
4149
@attr.s(auto_attribs=True, frozen=True)
4250
class Symbol:
4351
kind: str
@@ -59,9 +67,10 @@ class SymbolTable:
5967
@attr.s(auto_attribs=True)
6068
class State:
6169
header: Optional[Header] = None
62-
inputs: List[str] = attr.ib(factory=list)
63-
outputs: List[str] = attr.ib(factory=list)
64-
latches: List[str] = attr.ib(factory=list)
70+
inputs: List[int] = attr.ib(factory=list)
71+
outputs: List[int] = attr.ib(factory=list)
72+
latches: List[Latch] = attr.ib(factory=list)
73+
ands: List[And] = attr.ib(factory=list)
6574
symbols: SymbolTable = attr.ib(factory=SymbolTable)
6675
comments: Optional[List[str]] = None
6776
nodes: SortedDict = attr.ib(factory=SortedDict)
@@ -78,29 +87,46 @@ def remaining_outputs(self):
7887
def remaining_inputs(self):
7988
return self.header.num_inputs - len(self.inputs)
8089

90+
@property
91+
def remaining_ands(self):
92+
return self.header.num_ands - len(self.ands)
8193

82-
HEADER_PATTERN = re.compile(r"aag (\d+) (\d+) (\d+) (\d+) (\d+)\n")
8394

95+
def _consume_stream(stream, delim) -> str:
96+
line = bytearray()
97+
ch = -1
98+
delim = ord(delim)
99+
while ch != delim:
100+
ch = next(stream, delim)
101+
line.append(ch)
102+
return line.decode('ascii')
84103

85-
def parse_header(state, line) -> bool:
104+
105+
HEADER_PATTERN = re.compile(r"(a[ai]g) (\d+) (\d+) (\d+) (\d+) (\d+)\n")
106+
107+
108+
def parse_header(state, stream) -> bool:
86109
if state.header is not None:
87110
return False
88111

112+
line = _consume_stream(stream, '\n')
89113
match = HEADER_PATTERN.match(line)
90114
if not match:
91-
raise ValueError(f"Failed to parse aag HEADER. {line}")
115+
raise ValueError(f"Failed to parse aag/aig HEADER. {line}")
92116

93117
try:
94-
ids = fn.lmap(int, match.groups())
118+
binary_mode = match.group(1) == 'aig'
119+
ids = fn.lmap(int, match.groups()[1:])
95120

96121
if any(x < 0 for x in ids):
97-
raise ValueError("Indicies must be positive!")
122+
raise ValueError("Indices must be positive!")
98123

99124
max_idx, nin, nlatch, nout, nand = ids
100125
if nin + nlatch + nand > max_idx:
101126
raise ValueError("Sum of claimed indices greater than max.")
102127

103128
state.header = Header(
129+
binary_mode=binary_mode,
104130
max_var_index=max_idx,
105131
num_inputs=nin,
106132
num_latches=nlatch,
@@ -116,21 +142,38 @@ def parse_header(state, line) -> bool:
116142
IO_PATTERN = re.compile(r"(\d+)\s*\n")
117143

118144

119-
def parse_input(state, line) -> bool:
120-
match = IO_PATTERN.match(line)
121-
122-
if match is None or state.remaining_inputs <= 0:
123-
return False
124-
lit = int(line)
145+
def _add_input(state, lit):
125146
state.inputs.append(lit)
126147
state.nodes[lit] = set()
127-
return True
128148

129149

130-
def parse_output(state, line) -> bool:
150+
def parse_input(state, stream) -> bool:
    """Parse the next input declaration, if any inputs are still expected.

    Returns:
        ``True`` when one input line was consumed from ``stream``;
        ``False`` when no inputs remain, or when the header is in binary
        mode (all inputs are then generated at once without reading).

    Raises:
        ValueError: If the consumed line does not look like an input.
    """
    if state.remaining_inputs <= 0:
        return False

    header = state.header
    if header.binary_mode:
        # Nothing is read from the stream here: the inputs are the even
        # literals 2, 4, ..., 2 * num_inputs.
        for position in range(1, header.num_inputs + 1):
            _add_input(state, 2 * position)
        return False

    line = _consume_stream(stream, '\n')
    if IO_PATTERN.match(line) is None:
        raise ValueError(f"Expecting an input: {line}")

    _add_input(state, int(line))
    return True
167+
168+
169+
def parse_output(state, stream) -> bool:
170+
if state.remaining_outputs <= 0:
133171
return False
172+
173+
line = _consume_stream(stream, '\n')
174+
match = IO_PATTERN.match(line)
175+
if match is None:
176+
raise ValueError(f"Expecting an output: {line}")
134177
lit = int(line)
135178
state.outputs.append(lit)
136179
if lit & 1:
@@ -139,17 +182,28 @@ def parse_output(state, line) -> bool:
139182

140183

141184
LATCH_PATTERN = re.compile(r"(\d+) (\d+)(?: (\d+))?\n")
185+
LATCH_PATTERN_BINARY = re.compile(r"(\d+)(?: (\d+))?\n")
142186

143187

144-
def parse_latch(state, line) -> bool:
188+
def parse_latch(state, stream) -> bool:
145189
if state.remaining_latches <= 0:
146190
return False
147191

148-
match = LATCH_PATTERN.match(line)
149-
if match is None:
150-
raise ValueError("Expecting a latch: {line}")
192+
line = _consume_stream(stream, '\n')
193+
194+
if state.header.binary_mode:
195+
match = LATCH_PATTERN_BINARY.match(line)
196+
if match is None:
197+
raise ValueError(f"Expecting a latch: {line}")
198+
idx = state.header.num_inputs + len(state.latches) + 1
199+
lit = 2 * idx
200+
elems = (lit,) + match.groups()
201+
else:
202+
match = LATCH_PATTERN.match(line)
203+
if match is None:
204+
raise ValueError(f"Expecting a latch: {line}")
205+
elems = match.groups()
151206

152-
elems = match.groups()
153207
if elems[2] is None:
154208
elems = elems[:2] + (0,)
155209
elems = fn.lmap(int, elems)
@@ -165,30 +219,69 @@ def parse_latch(state, line) -> bool:
165219
AND_PATTERN = re.compile(r"(\d+) (\d+) (\d+)\s*\n")
166220

167221

168-
def parse_and(state, line) -> bool:
169-
if state.header.num_ands <= 0:
170-
return False
171-
172-
match = AND_PATTERN.match(line)
173-
if match is None:
174-
return False
175-
176-
elems = fn.lmap(int, match.groups())
177-
state.header.num_ands -= 1
178-
deps = set(elems[1:])
179-
state.nodes[elems[0]] = deps
222+
def _read_delta(data):
223+
ch = next(data)
224+
i = 0
225+
delta = 0
226+
while (ch & 0x80) != 0:
227+
if i == 5:
228+
raise ValueError("Invalid byte in delta encoding")
229+
delta |= (ch & 0x7f) << (7 * i)
230+
i += 1
231+
ch = next(data)
232+
if i == 5 and ch >= 8:
233+
raise ValueError("Invalid byte in delta encoding")
234+
235+
delta |= ch << (7 * i)
236+
return delta
237+
238+
239+
def _add_and(state, elems):
    """Record one AND gate in ``state``.

    Args:
        state: Parser state; its ``ands`` list and ``nodes`` mapping are
            updated in place.
        elems: Three values (ints or digit strings) giving the gate's
            left-hand side and its two right-hand-side literals.
    """
    lhs, rhs0, rhs1 = (int(elem) for elem in elems)
    state.ands.append(And(lhs, rhs0, rhs1))

    operands = {rhs0, rhs1}
    state.nodes[lhs] = operands

    # An odd operand is a negated literal: make it depend on the literal
    # it negates (same value with the low bit cleared).
    for negated in (lit for lit in operands if lit & 1):
        state.nodes[negated] = {negated ^ 1}
247+
248+
249+
def parse_and(state, stream) -> bool:
    """Parse the next AND gate, if any gates are still expected.

    In ASCII mode one text line is consumed. In binary mode the gate's
    left-hand side is derived from its position and the two right-hand
    sides are read as successive deltas.

    Returns:
        ``True`` when a gate was read and recorded, ``False`` when no AND
        gates remain.

    Raises:
        ValueError: If the ASCII line is not an AND definition, or a
            binary delta is larger than the value it is subtracted from.
    """
    if state.remaining_ands <= 0:
        return False

    header = state.header
    if not header.binary_mode:
        line = _consume_stream(stream, '\n')
        match = AND_PATTERN.match(line)
        if match is None:
            raise ValueError(f"Expecting an and: {line}")
        gate = match.groups()
    else:
        # Gates are numbered right after the inputs and latches.
        preceding = header.num_inputs + header.num_latches + len(state.ands)
        lhs = 2 * (preceding + 1)

        delta = _read_delta(stream)
        if delta > lhs:
            raise ValueError(f"Invalid lhs {lhs} or delta {delta}")
        rhs0 = lhs - delta

        delta = _read_delta(stream)
        if delta > rhs0:
            raise ValueError(f"Invalid rhs0 {rhs0} or delta {delta}")
        rhs1 = rhs0 - delta

        gate = (lhs, rhs0, rhs1)

    _add_and(state, gate)
    return True
184273

185274

186275
SYM_PATTERN = re.compile(r"([ilo])(\d+) (.*)\s*\n")
187276

188277

189-
def parse_symbol(state, line) -> bool:
278+
def parse_symbol(state, stream) -> bool:
279+
line = _consume_stream(stream, '\n')
190280
match = SYM_PATTERN.match(line)
191281
if match is None:
282+
# We might have consumed the 'c' starting the comments section
283+
if line.rstrip() == 'c':
284+
state.comments = []
192285
return False
193286

194287
kind, idx, name = match.groups()
@@ -202,7 +295,8 @@ def parse_symbol(state, line) -> bool:
202295
return True
203296

204297

205-
def parse_comment(state, line) -> bool:
298+
def parse_comment(state, stream) -> bool:
299+
line = _consume_stream(stream, '\n')
206300
if state.comments is not None:
207301
state.comments.append(line.rstrip())
208302
elif line.rstrip() == 'c':
@@ -227,25 +321,30 @@ def finish_table(table, keys):
227321
return {table[i]: key for i, key in enumerate(keys)}
228322

229323

230-
def parse(lines, to_aig: bool = True):
231-
if isinstance(lines, str):
232-
lines = io.StringIO(lines)
324+
def parse(stream):
325+
if isinstance(stream, list):
326+
stream = ''.join(stream)
327+
if isinstance(stream, str):
328+
stream = bytes(stream, 'ascii')
329+
stream = iter(stream)
233330

234331
state = State()
235332
parsers = parse_seq()
236333
parser = next(parsers)
237334

238-
for i, line in enumerate(lines):
239-
while not parser(state, line):
335+
i = 0
336+
while stream.__length_hint__() > 0:
337+
i += 1
338+
while not parser(state, stream):
240339
parser = next(parsers, None)
241340

242341
if parser is None:
243-
raise ValueError(NOT_DONE_PARSING_ERROR.format(i + 1, state))
342+
raise ValueError(NOT_DONE_PARSING_ERROR.format(i, state))
244343

245344
if parser not in (parse_header, parse_output, parse_comment, parse_symbol):
246345
raise ValueError(DONE_PARSING_ERROR.format(state))
247346

248-
assert state.header.num_ands == 0
347+
assert state.remaining_ands == 0
249348
assert state.remaining_inputs == 0
250349
assert state.remaining_outputs == 0
251350
assert state.remaining_latches == 0
@@ -260,6 +359,7 @@ def parse(lines, to_aig: bool = True):
260359

261360
# Create expression DAG.
262361
latch_ids = {latch.id: name for name, latch in latches.items()}
362+
and_ids = {and_.lhs: and_ for and_ in state.ands}
263363
lit2expr = {0: A.aig.ConstFalse()}
264364
for lit in toposort_flatten(state.nodes):
265365
if lit == 0:
@@ -272,6 +372,7 @@ def parse(lines, to_aig: bool = True):
272372
elif lit & 1:
273373
lit2expr[lit] = A.aig.Inverter(lit2expr[lit & -2])
274374
else:
375+
assert lit in and_ids
275376
nodes = [lit2expr[lit2] for lit2 in state.nodes[lit]]
276377
lit2expr[lit] = reduce(A.aig.AndGate, nodes)
277378

@@ -284,9 +385,9 @@ def parse(lines, to_aig: bool = True):
284385
)
285386

286387

287-
def load(path: str, to_aig: bool = True):
288-
with open(path, 'r') as f:
289-
return parse(''.join(f.readlines()), to_aig=to_aig)
388+
def load(path: str):
    """Read the file at ``path`` as raw bytes and parse it.

    Returns:
        Whatever ``parse`` returns for the file's contents.
    """
    with open(path, mode='rb') as handle:
        raw = handle.read()
    return parse(raw)
290391

291392

292393
__all__ = ['load', 'parse']

0 commit comments

Comments
 (0)