1- import io
21import re
32from collections import defaultdict
43from functools import reduce
54from typing import Mapping , List , Optional
5+ from uuid import uuid1
66
77import attr
88import funcy as fn
99from bidict import bidict
10+ from sortedcontainers import SortedDict
1011from toposort import toposort_flatten
11- from uuid import uuid1
12- from sortedcontainers import SortedList , SortedSet , SortedDict
1312
1413import aiger as A
1514
1615
1716@attr .s (auto_attribs = True , repr = False )
1817class Header :
18+ binary_mode : bool
1919 max_var_index : int
2020 num_inputs : int
2121 num_latches : int
2222 num_outputs : int
2323 num_ands : int
2424
2525 def __repr__ (self ):
26- return f"aag { self .max_var_index } { self .num_inputs } " \
27- f"{ self .num_latches } { self .num_outputs } { self .num_ands } "
26+ mode = 'aig' if self .binary_mode else 'aag'
27+ return f"{ mode } { self .max_var_index } { self .num_inputs } " \
28+ f"{ self .num_latches } { self .num_outputs } { self .num_ands } "
2829
2930
3031NOT_DONE_PARSING_ERROR = "Parsing rules exhausted at line {}!\n {}"
@@ -38,6 +39,13 @@ class Latch:
3839 init : bool = attr .ib (converter = bool )
3940
4041
42+ @attr .s (auto_attribs = True , frozen = True )
43+ class And :
44+ lhs : int
45+ rhs0 : int
46+ rhs1 : int
47+
48+
4149@attr .s (auto_attribs = True , frozen = True )
4250class Symbol :
4351 kind : str
@@ -59,9 +67,10 @@ class SymbolTable:
5967@attr .s (auto_attribs = True )
6068class State :
6169 header : Optional [Header ] = None
62- inputs : List [str ] = attr .ib (factory = list )
63- outputs : List [str ] = attr .ib (factory = list )
64- latches : List [str ] = attr .ib (factory = list )
70+ inputs : List [int ] = attr .ib (factory = list )
71+ outputs : List [int ] = attr .ib (factory = list )
72+ latches : List [Latch ] = attr .ib (factory = list )
73+ ands : List [And ] = attr .ib (factory = list )
6574 symbols : SymbolTable = attr .ib (factory = SymbolTable )
6675 comments : Optional [List [str ]] = None
6776 nodes : SortedDict = attr .ib (factory = SortedDict )
@@ -78,29 +87,46 @@ def remaining_outputs(self):
7887 def remaining_inputs (self ):
7988 return self .header .num_inputs - len (self .inputs )
8089
90+ @property
91+ def remaining_ands (self ):
92+ return self .header .num_ands - len (self .ands )
8193
82- HEADER_PATTERN = re .compile (r"aag (\d+) (\d+) (\d+) (\d+) (\d+)\n" )
8394
95+ def _consume_stream (stream , delim ) -> str :
96+ line = bytearray ()
97+ ch = - 1
98+ delim = ord (delim )
99+ while ch != delim :
100+ ch = next (stream , delim )
101+ line .append (ch )
102+ return line .decode ('ascii' )
84103
85- def parse_header (state , line ) -> bool :
104+
105+ HEADER_PATTERN = re .compile (r"(a[ai]g) (\d+) (\d+) (\d+) (\d+) (\d+)\n" )
106+
107+
108+ def parse_header (state , stream ) -> bool :
86109 if state .header is not None :
87110 return False
88111
112+ line = _consume_stream (stream , '\n ' )
89113 match = HEADER_PATTERN .match (line )
90114 if not match :
91- raise ValueError (f"Failed to parse aag HEADER. { line } " )
115+ raise ValueError (f"Failed to parse aag/aig HEADER. { line } " )
92116
93117 try :
94- ids = fn .lmap (int , match .groups ())
118+ binary_mode = match .group (1 ) == 'aig'
119+ ids = fn .lmap (int , match .groups ()[1 :])
95120
96121 if any (x < 0 for x in ids ):
97- raise ValueError ("Indicies must be positive!" )
122+ raise ValueError ("Indices must be positive!" )
98123
99124 max_idx , nin , nlatch , nout , nand = ids
100125 if nin + nlatch + nand > max_idx :
101126 raise ValueError ("Sum of claimed indices greater than max." )
102127
103128 state .header = Header (
129+ binary_mode = binary_mode ,
104130 max_var_index = max_idx ,
105131 num_inputs = nin ,
106132 num_latches = nlatch ,
@@ -116,21 +142,38 @@ def parse_header(state, line) -> bool:
116142IO_PATTERN = re .compile (r"(\d+)\s*\n" )
117143
118144
119- def parse_input (state , line ) -> bool :
120- match = IO_PATTERN .match (line )
121-
122- if match is None or state .remaining_inputs <= 0 :
123- return False
124- lit = int (line )
145+ def _add_input (state , lit ):
125146 state .inputs .append (lit )
126147 state .nodes [lit ] = set ()
127- return True
128148
129149
130- def parse_output (state , line ) -> bool :
150+ def parse_input (state , stream ) -> bool :
151+ if state .remaining_inputs <= 0 :
152+ return False
153+
154+ if state .header .binary_mode :
155+ for lit in range (2 , 2 * (state .header .num_inputs + 1 ), 2 ):
156+ _add_input (state , lit )
157+ return False
158+
159+ line = _consume_stream (stream , '\n ' )
131160 match = IO_PATTERN .match (line )
132- if match is None or state .remaining_outputs <= 0 :
161+
162+ if match is None :
163+ raise ValueError (f"Expecting an input: { line } " )
164+
165+ _add_input (state , int (line ))
166+ return True
167+
168+
169+ def parse_output (state , stream ) -> bool :
170+ if state .remaining_outputs <= 0 :
133171 return False
172+
173+ line = _consume_stream (stream , '\n ' )
174+ match = IO_PATTERN .match (line )
175+ if match is None :
176+ raise ValueError (f"Expecting an output: { line } " )
134177 lit = int (line )
135178 state .outputs .append (lit )
136179 if lit & 1 :
@@ -139,17 +182,28 @@ def parse_output(state, line) -> bool:
139182
140183
141184LATCH_PATTERN = re .compile (r"(\d+) (\d+)(?: (\d+))?\n" )
185+ LATCH_PATTERN_BINARY = re .compile (r"(\d+)(?: (\d+))?\n" )
142186
143187
144- def parse_latch (state , line ) -> bool :
188+ def parse_latch (state , stream ) -> bool :
145189 if state .remaining_latches <= 0 :
146190 return False
147191
148- match = LATCH_PATTERN .match (line )
149- if match is None :
150- raise ValueError ("Expecting a latch: {line}" )
192+ line = _consume_stream (stream , '\n ' )
193+
194+ if state .header .binary_mode :
195+ match = LATCH_PATTERN_BINARY .match (line )
196+ if match is None :
197+ raise ValueError (f"Expecting a latch: { line } " )
198+ idx = state .header .num_inputs + len (state .latches ) + 1
199+ lit = 2 * idx
200+ elems = (lit ,) + match .groups ()
201+ else :
202+ match = LATCH_PATTERN .match (line )
203+ if match is None :
204+ raise ValueError (f"Expecting a latch: { line } " )
205+ elems = match .groups ()
151206
152- elems = match .groups ()
153207 if elems [2 ] is None :
154208 elems = elems [:2 ] + (0 ,)
155209 elems = fn .lmap (int , elems )
@@ -165,30 +219,69 @@ def parse_latch(state, line) -> bool:
165219AND_PATTERN = re .compile (r"(\d+) (\d+) (\d+)\s*\n" )
166220
167221
168- def parse_and (state , line ) -> bool :
169- if state .header .num_ands <= 0 :
170- return False
171-
172- match = AND_PATTERN .match (line )
173- if match is None :
174- return False
175-
176- elems = fn .lmap (int , match .groups ())
177- state .header .num_ands -= 1
178- deps = set (elems [1 :])
179- state .nodes [elems [0 ]] = deps
222+ def _read_delta (data ):
223+ ch = next (data )
224+ i = 0
225+ delta = 0
226+ while (ch & 0x80 ) != 0 :
227+ if i == 5 :
228+ raise ValueError ("Invalid byte in delta encoding" )
229+ delta |= (ch & 0x7f ) << (7 * i )
230+ i += 1
231+ ch = next (data )
232+ if i == 5 and ch >= 8 :
233+ raise ValueError ("Invalid byte in delta encoding" )
234+
235+ delta |= ch << (7 * i )
236+ return delta
237+
238+
239+ def _add_and (state , elems ):
240+ lhs , rhs0 , rhs1 = fn .lmap (int , elems )
241+ state .ands .append (And (lhs , rhs0 , rhs1 ))
242+ deps = {rhs0 , rhs1 }
243+ state .nodes [lhs ] = deps
180244 for dep in deps :
181245 if dep & 1 :
182246 state .nodes [dep ] = {dep ^ 1 }
247+
248+
249+ def parse_and (state , stream ) -> bool :
250+ if state .remaining_ands <= 0 :
251+ return False
252+
253+ if state .header .binary_mode :
254+ idx = state .header .num_inputs + state .header .num_latches + len (state .ands ) + 1
255+ lhs = 2 * idx
256+ delta = _read_delta (stream )
257+ if delta > lhs :
258+ raise ValueError (f"Invalid lhs { lhs } or delta { delta } " )
259+ rhs0 = lhs - delta
260+ delta = _read_delta (stream )
261+ if delta > rhs0 :
262+ raise ValueError (f"Invalid rhs0 { rhs0 } or delta { delta } " )
263+ rhs1 = rhs0 - delta
264+ else :
265+ line = _consume_stream (stream , '\n ' )
266+ match = AND_PATTERN .match (line )
267+ if match is None :
268+ raise ValueError (f"Expecting an and: { line } " )
269+ lhs , rhs0 , rhs1 = match .groups ()
270+
271+ _add_and (state , (lhs , rhs0 , rhs1 ))
183272 return True
184273
185274
186275SYM_PATTERN = re .compile (r"([ilo])(\d+) (.*)\s*\n" )
187276
188277
189- def parse_symbol (state , line ) -> bool :
278+ def parse_symbol (state , stream ) -> bool :
279+ line = _consume_stream (stream , '\n ' )
190280 match = SYM_PATTERN .match (line )
191281 if match is None :
282+ # We might have consumed the 'c' starting the comments section
283+ if line .rstrip () == 'c' :
284+ state .comments = []
192285 return False
193286
194287 kind , idx , name = match .groups ()
@@ -202,7 +295,8 @@ def parse_symbol(state, line) -> bool:
202295 return True
203296
204297
205- def parse_comment (state , line ) -> bool :
298+ def parse_comment (state , stream ) -> bool :
299+ line = _consume_stream (stream , '\n ' )
206300 if state .comments is not None :
207301 state .comments .append (line .rstrip ())
208302 elif line .rstrip () == 'c' :
@@ -227,25 +321,30 @@ def finish_table(table, keys):
227321 return {table [i ]: key for i , key in enumerate (keys )}
228322
229323
230- def parse (lines , to_aig : bool = True ):
231- if isinstance (lines , str ):
232- lines = io .StringIO (lines )
324+ def parse (stream ):
325+ if isinstance (stream , list ):
326+ stream = '' .join (stream )
327+ if isinstance (stream , str ):
328+ stream = bytes (stream , 'ascii' )
329+ stream = iter (stream )
233330
234331 state = State ()
235332 parsers = parse_seq ()
236333 parser = next (parsers )
237334
238- for i , line in enumerate (lines ):
239- while not parser (state , line ):
335+ i = 0
336+ while stream .__length_hint__ () > 0 :
337+ i += 1
338+ while not parser (state , stream ):
240339 parser = next (parsers , None )
241340
242341 if parser is None :
243- raise ValueError (NOT_DONE_PARSING_ERROR .format (i + 1 , state ))
342+ raise ValueError (NOT_DONE_PARSING_ERROR .format (i , state ))
244343
245344 if parser not in (parse_header , parse_output , parse_comment , parse_symbol ):
246345 raise ValueError (DONE_PARSING_ERROR .format (state ))
247346
248- assert state .header . num_ands == 0
347+ assert state .remaining_ands == 0
249348 assert state .remaining_inputs == 0
250349 assert state .remaining_outputs == 0
251350 assert state .remaining_latches == 0
@@ -260,6 +359,7 @@ def parse(lines, to_aig: bool = True):
260359
261360 # Create expression DAG.
262361 latch_ids = {latch .id : name for name , latch in latches .items ()}
362+ and_ids = {and_ .lhs : and_ for and_ in state .ands }
263363 lit2expr = {0 : A .aig .ConstFalse ()}
264364 for lit in toposort_flatten (state .nodes ):
265365 if lit == 0 :
@@ -272,6 +372,7 @@ def parse(lines, to_aig: bool = True):
272372 elif lit & 1 :
273373 lit2expr [lit ] = A .aig .Inverter (lit2expr [lit & - 2 ])
274374 else :
375+ assert lit in and_ids
275376 nodes = [lit2expr [lit2 ] for lit2 in state .nodes [lit ]]
276377 lit2expr [lit ] = reduce (A .aig .AndGate , nodes )
277378
@@ -284,9 +385,9 @@ def parse(lines, to_aig: bool = True):
284385 )
285386
286387
287- def load (path : str , to_aig : bool = True ):
288- with open (path , 'r ' ) as f :
289- return parse ('' . join ( f . readlines ()), to_aig = to_aig )
388+ def load (path : str ):
389+ with open (path , 'rb ' ) as f :
390+ return parse (f . read () )
290391
291392
292393__all__ = ['load' , 'parse' ]
0 commit comments