|
1 | 1 | from collections import defaultdict |
2 | | -from itertools import chain |
3 | 2 | from functools import reduce |
4 | | -from typing import Tuple, FrozenSet, NamedTuple, Union, Mapping, List |
| 3 | +from typing import Tuple, FrozenSet, NamedTuple, Union |
5 | 4 |
|
6 | 5 | import funcy as fn |
7 | 6 | from toposort import toposort |
8 | 7 |
|
9 | 8 | from aiger import common as cmn |
| 9 | +from aiger import parser |
10 | 10 |
|
11 | 11 |
|
12 | 12 | def timed_name(name, time): |
@@ -73,6 +73,8 @@ class AIG(NamedTuple): |
73 | 73 | latch2init: FrozenSet[Tuple[str, Node]] = frozenset() |
74 | 74 | comments: Tuple[str] = () |
75 | 75 |
|
| 76 | + _to_aag = parser.aig2aag |
| 77 | + |
76 | 78 | def __repr__(self): |
77 | 79 | return repr(self._to_aag()) |
78 | 80 |
|
@@ -229,34 +231,6 @@ def _unroll(): |
229 | 231 |
|
230 | 232 | return unrolled |
231 | 233 |
|
232 | | - def _to_aag(self): |
233 | | - aag, max_idx, lit_map = _to_aag( |
234 | | - self.cones | self.latch_cones, |
235 | | - AAG({}, {}, {}, [], self.comments), |
236 | | - ) |
237 | | - |
238 | | - # Check that all inputs have a lit. |
239 | | - for name in filter(lambda x: x not in aag.inputs, self.inputs): |
240 | | - aag.inputs[name] = lit_map[name] = 2 * max_idx |
241 | | - max_idx += 1 |
242 | | - |
243 | | - # Update cone maps. |
244 | | - aag.outputs.update({k: lit_map[cone] for k, cone in self.node_map}) |
245 | | - latch2init = dict(self.latch2init) |
246 | | - for name, cone in self.latch_map: |
247 | | - latch = LatchIn(name) |
248 | | - if latch not in lit_map: |
249 | | - lit = lit_map[latch] = 2 * max_idx |
250 | | - max_idx += 1 |
251 | | - else: |
252 | | - lit = lit_map[latch] |
253 | | - |
254 | | - init = int(latch2init[name]) |
255 | | - ilit = lit_map[cone] |
256 | | - aag.latches[name] = lit, ilit, init |
257 | | - |
258 | | - return aag |
259 | | - |
260 | 234 | def write(self, path): |
261 | 235 | with open(path, 'w') as f: |
262 | 236 | f.write(repr(self)) |
@@ -293,181 +267,6 @@ def _mod(node): |
293 | 267 | ) |
294 | 268 |
|
295 | 269 |
|
296 | | -def _to_idx(lit): |
297 | | - """AAG format uses least significant bit to encode an inverter. |
298 | | - The index is thus the interal literal shifted by one bit.""" |
299 | | - return lit >> 1 |
300 | | - |
301 | | - |
302 | | -def _polarity(i): |
303 | | - return Inverter if i & 1 == 1 else lambda x: x |
304 | | - |
305 | | - |
306 | | -class Header(NamedTuple): |
307 | | - max_var_index: int |
308 | | - num_inputs: int |
309 | | - num_latches: int |
310 | | - num_outputs: int |
311 | | - num_ands: int |
312 | | - |
313 | | - |
314 | | -class AAG(NamedTuple): |
315 | | - inputs: Mapping[str, int] |
316 | | - latches: Mapping[str, Tuple[int]] |
317 | | - outputs: Mapping[str, int] |
318 | | - gates: List[Tuple[int]] |
319 | | - comments: Tuple[str] |
320 | | - |
321 | | - @property |
322 | | - def header(self): |
323 | | - literals = chain( |
324 | | - self.inputs.values(), |
325 | | - self.outputs.values(), |
326 | | - fn.pluck(0, self.gates), fn.pluck(0, self.latches.values()) |
327 | | - ) |
328 | | - max_idx = max(map(_to_idx, literals), default=0) |
329 | | - return Header(max_idx, *map(len, self[:-1])) |
330 | | - |
331 | | - def __repr__(self): |
332 | | - if self.inputs: |
333 | | - input_names, input_lits = zip(*list(self.inputs.items())) |
334 | | - if self.outputs: |
335 | | - output_names, output_lits = zip(*list(self.outputs.items())) |
336 | | - if self.latches: |
337 | | - latch_names, latch_lits = zip(*list(self.latches.items())) |
338 | | - |
339 | | - out = f"aag " + " ".join(map(str, self.header)) + '\n' |
340 | | - if self.inputs: |
341 | | - out += '\n'.join(map(str, input_lits)) + '\n' |
342 | | - if self.latches: |
343 | | - out += '\n'.join([' '.join(map(str, xs)) |
344 | | - for xs in latch_lits]) + '\n' |
345 | | - if self.outputs: |
346 | | - out += '\n'.join(map(str, output_lits)) + '\n' |
347 | | - |
348 | | - if self.gates: |
349 | | - out += '\n'.join([' '.join(map(str, xs)) |
350 | | - for xs in self.gates]) + '\n' |
351 | | - if self.inputs: |
352 | | - out += '\n'.join( |
353 | | - f"i{idx} {name}" for idx, name in enumerate(input_names) |
354 | | - ) + '\n' |
355 | | - if self.outputs: |
356 | | - out += '\n'.join( |
357 | | - f"o{idx} {name}" for idx, name in enumerate(output_names) |
358 | | - ) + '\n' |
359 | | - |
360 | | - if self.latches: |
361 | | - out += '\n'.join( |
362 | | - f"l{idx} {name}" for idx, name in enumerate(latch_names) |
363 | | - ) + '\n' |
364 | | - |
365 | | - if self.comments: |
366 | | - out += 'c\n' + '\n'.join(self.comments) |
367 | | - if out[-1] != '\n': |
368 | | - out += '\n' |
369 | | - return out |
370 | | - |
371 | | - def _to_aig(self): |
372 | | - gate_order, latch_order = self.eval_order_and_gate_lookup |
373 | | - |
374 | | - lookup = fn.merge( |
375 | | - {0: ConstFalse()}, |
376 | | - {_to_idx(l): Input(n) for n, l in self.inputs.items()}, |
377 | | - { |
378 | | - _to_idx(l): LatchIn(n) |
379 | | - for n, (l, _, init) in self.latches.items() |
380 | | - }, |
381 | | - ) |
382 | | - latches = set() |
383 | | - and_dependencies = {i: (l, r) for i, l, r in self.gates} |
384 | | - for gate in fn.cat(gate_order): |
385 | | - if _to_idx(gate) in lookup: |
386 | | - continue |
387 | | - |
388 | | - inputs = and_dependencies[gate] |
389 | | - sources = [_polarity(i)(lookup[_to_idx(i)]) for i in inputs] |
390 | | - lookup[_to_idx(gate)] = AndGate(*sources) |
391 | | - |
392 | | - latch_dependencies = { |
393 | | - i: (n, dep) for n, (i, dep, _) in self.latches.items() |
394 | | - } |
395 | | - for gate in fn.cat(latch_order): |
396 | | - assert _to_idx(gate) in lookup |
397 | | - if not isinstance(lookup[_to_idx(gate)], LatchIn): |
398 | | - continue |
399 | | - |
400 | | - name, dep = latch_dependencies[gate] |
401 | | - source = _polarity(dep)(lookup[_to_idx(dep)]) |
402 | | - latches.add((name, source)) |
403 | | - |
404 | | - def get_output(v): |
405 | | - idx = _to_idx(v) |
406 | | - return _polarity(v)(lookup[idx]) |
407 | | - |
408 | | - top_level = ((k, get_output(v)) for k, v in self.outputs.items()) |
409 | | - return AIG( |
410 | | - inputs=frozenset(self.inputs), |
411 | | - node_map=frozenset(top_level), |
412 | | - latch_map=frozenset(latches), |
413 | | - latch2init=frozenset( |
414 | | - (n, bool(init)) for n, (_, _, init) in self.latches.items() |
415 | | - ), |
416 | | - comments=self.comments |
417 | | - ) |
418 | | - |
419 | | - @property |
420 | | - def eval_order_and_gate_lookup(self): |
421 | | - deps = fn.merge( |
422 | | - {a & -2: {b & -2, c & -2} for a, b, c in self.gates}, |
423 | | - {a & -2: set() for a in self.inputs.values()}, |
424 | | - {a & -2: set() for a, _, _ in self.latches.values()}, # LatchIn |
425 | | - ) |
426 | | - latch_deps = {a & -2: {b & -2} for a, b, _ in self.latches.values()} |
427 | | - return list(toposort(deps)), list(toposort(latch_deps)) |
428 | | - |
429 | | - |
430 | | -def _to_aag(gates, aag: AAG = None, *, max_idx=1, lit_map=None): |
431 | | - if lit_map is None: |
432 | | - lit_map = {} |
433 | | - |
434 | | - if not gates: |
435 | | - return aag, max_idx, lit_map |
436 | | - |
437 | | - # Recurse to update get aag for subtrees. |
438 | | - for c in fn.mapcat(lambda g: g.children, gates): |
439 | | - if c in lit_map: |
440 | | - continue |
441 | | - aag, max_idx, lit_map = _to_aag( |
442 | | - [c], aag, max_idx=max_idx, lit_map=lit_map |
443 | | - ) |
444 | | - |
445 | | - # Update aag with current level. |
446 | | - for gate in gates: |
447 | | - if gate in lit_map: |
448 | | - continue |
449 | | - |
450 | | - if isinstance(gate, Inverter): |
451 | | - input_lit = lit_map[gate.input] |
452 | | - lit_map[gate] = (input_lit & -2) | (1 ^ (input_lit & 1)) |
453 | | - continue |
454 | | - elif isinstance(gate, ConstFalse): |
455 | | - lit_map[gate] = 0 |
456 | | - continue |
457 | | - |
458 | | - # Must be And, Latch, or Input |
459 | | - lit_map[gate] = 2 * max_idx |
460 | | - max_idx += 1 |
461 | | - if isinstance(gate, AndGate): |
462 | | - encoded = tuple(map(lit_map.get, (gate, gate.left, gate.right))) |
463 | | - aag.gates.append(encoded) |
464 | | - |
465 | | - elif isinstance(gate, Input): |
466 | | - aag.inputs[gate.name] = lit_map[gate] |
467 | | - |
468 | | - return aag, max_idx, lit_map |
469 | | - |
470 | | - |
471 | 270 | def _dependency_graph(nodes): |
472 | 271 | queue, deps, visited = list(nodes), defaultdict(set), set() |
473 | 272 | while queue: |
|
0 commit comments