mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
Derive an industrial-strength conjoin() via cross-recursion loop unrolling,
and fiddle the conjoin tests to exercise all the new possible paths.
This commit is contained in:
parent
4efb6e9643
commit
c468fd28b6
1 changed files with 83 additions and 9 deletions
|
@ -776,6 +776,62 @@ def conjoin(gs):
|
||||||
for x in gen(0):
|
for x in gen(0):
|
||||||
yield x
|
yield x
|
||||||
|
|
||||||
|
# That works fine, but recursing a level and checking i against len(gs) for
|
||||||
|
# each item produced is inefficient. By doing manual loop unrolling across
|
||||||
|
# generator boundaries, it's possible to eliminate most of that overhead.
|
||||||
|
# This isn't worth the bother *in general* for generators, but conjoin() is
|
||||||
|
# a core building block for some CPU-intensive generator applications.
|
||||||
|
|
||||||
|
def conjoin(gs):
|
||||||
|
|
||||||
|
n = len(gs)
|
||||||
|
values = [None] * n
|
||||||
|
|
||||||
|
# Do one loop nest at time recursively, until the # of loop nests
|
||||||
|
# remaining is divisible by 3.
|
||||||
|
|
||||||
|
def gen(i, values=values):
|
||||||
|
if i >= n:
|
||||||
|
yield values
|
||||||
|
|
||||||
|
elif (n-i) % 3:
|
||||||
|
ip1 = i+1
|
||||||
|
for values[i] in gs[i]():
|
||||||
|
for x in gen(ip1):
|
||||||
|
yield x
|
||||||
|
|
||||||
|
else:
|
||||||
|
for x in _gen3(i):
|
||||||
|
yield x
|
||||||
|
|
||||||
|
# Do three loop nests at a time, recursing only if at least three more
|
||||||
|
# remain. Don't call directly: this is an internal optimization for
|
||||||
|
# gen's use.
|
||||||
|
|
||||||
|
def _gen3(i, values=values):
|
||||||
|
assert i < n and (n-i) % 3 == 0
|
||||||
|
ip1, ip2, ip3 = i+1, i+2, i+3
|
||||||
|
g, g1, g2 = gs[i : ip3]
|
||||||
|
|
||||||
|
if ip3 >= n:
|
||||||
|
# These are the last three, so we can yield values directly.
|
||||||
|
for values[i] in g():
|
||||||
|
for values[ip1] in g1():
|
||||||
|
for values[ip2] in g2():
|
||||||
|
yield values
|
||||||
|
|
||||||
|
else:
|
||||||
|
# At least 6 loop nests remain; peel off 3 and recurse for the
|
||||||
|
# rest.
|
||||||
|
for values[i] in g():
|
||||||
|
for values[ip1] in g1():
|
||||||
|
for values[ip2] in g2():
|
||||||
|
for x in _gen3(ip3):
|
||||||
|
yield x
|
||||||
|
|
||||||
|
for x in gen(0):
|
||||||
|
yield x
|
||||||
|
|
||||||
# A conjoin-based N-Queens solver.
|
# A conjoin-based N-Queens solver.
|
||||||
|
|
||||||
class Queens:
|
class Queens:
|
||||||
|
@ -804,11 +860,10 @@ class Queens:
|
||||||
def rowgen(rowuses=rowuses):
|
def rowgen(rowuses=rowuses):
|
||||||
for j in rangen:
|
for j in rangen:
|
||||||
uses = rowuses[j]
|
uses = rowuses[j]
|
||||||
if uses & self.used:
|
if uses & self.used == 0:
|
||||||
continue
|
self.used |= uses
|
||||||
self.used |= uses
|
yield j
|
||||||
yield j
|
self.used &= ~uses
|
||||||
self.used &= ~uses
|
|
||||||
|
|
||||||
self.rowgenerators.append(rowgen)
|
self.rowgenerators.append(rowgen)
|
||||||
|
|
||||||
|
@ -834,10 +889,7 @@ conjoin_tests = """
|
||||||
Generate the 3-bit binary numbers in order. This illustrates dumbest-
|
Generate the 3-bit binary numbers in order. This illustrates dumbest-
|
||||||
possible use of conjoin, just to generate the full cross-product.
|
possible use of conjoin, just to generate the full cross-product.
|
||||||
|
|
||||||
>>> def g():
|
>>> for c in conjoin([lambda: (0, 1)] * 3):
|
||||||
... return [0, 1]
|
|
||||||
|
|
||||||
>>> for c in conjoin([g] * 3):
|
|
||||||
... print c
|
... print c
|
||||||
[0, 0, 0]
|
[0, 0, 0]
|
||||||
[0, 0, 1]
|
[0, 0, 1]
|
||||||
|
@ -848,6 +900,28 @@ possible use of conjoin, just to generate the full cross-product.
|
||||||
[1, 1, 0]
|
[1, 1, 0]
|
||||||
[1, 1, 1]
|
[1, 1, 1]
|
||||||
|
|
||||||
|
For efficiency in typical backtracking apps, conjoin() yields the same list
|
||||||
|
object each time. So if you want to save away a full account of its
|
||||||
|
generated sequence, you need to copy its results.
|
||||||
|
|
||||||
|
>>> def gencopy(iterator):
|
||||||
|
... for x in iterator:
|
||||||
|
... yield x[:]
|
||||||
|
|
||||||
|
>>> for n in range(10):
|
||||||
|
... all = list(gencopy(conjoin([lambda: (0, 1)] * n)))
|
||||||
|
... print n, len(all), all[0] == [0] * n, all[-1] == [1] * n
|
||||||
|
0 1 1 1
|
||||||
|
1 2 1 1
|
||||||
|
2 4 1 1
|
||||||
|
3 8 1 1
|
||||||
|
4 16 1 1
|
||||||
|
5 32 1 1
|
||||||
|
6 64 1 1
|
||||||
|
7 128 1 1
|
||||||
|
8 256 1 1
|
||||||
|
9 512 1 1
|
||||||
|
|
||||||
And run an 8-queens solver.
|
And run an 8-queens solver.
|
||||||
|
|
||||||
>>> q = Queens(8)
|
>>> q = Queens(8)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue