There might be a faster way to do this, but we can work it out pretty well with numpy and itertools. To get started, for each unique label k we need to partition source_list into the entries that correspond to label k (the "positive" part) and those that don't (the "negative" part).
>>> source_list = [1,2,3,4,5,6,7,8] #... size N
>>> labels_list = [1,1,0,2,1,3,6,2] #... size N
>>> for key in set(labels_list):
...     positive = [s for i,s in enumerate(source_list) if labels_list[i] == key]
...     negative = [s for i,s in enumerate(source_list) if labels_list[i] != key]
...     print(key, positive, negative)
...
0 [3] [1, 2, 4, 5, 6, 7, 8]
1 [1, 2, 5] [3, 4, 6, 7, 8]
2 [4, 8] [1, 2, 3, 5, 6, 7]
3 [6] [1, 2, 3, 4, 5, 7, 8]
6 [7] [1, 2, 3, 4, 5, 6, 8]
We can speed this up a little with NumPy and masking, which essentially lets us avoid doing the list comprehension twice.
>>> import numpy as np
>>> source_array = np.array([1,2,3,4,5,6,7,8])
>>> labels_array = np.array([1,1,0,2,1,3,6,2])
>>> for key in np.unique(labels_array):
...     mask = (labels_array == key)
...     positive = source_array[mask]
...     negative = source_array[~mask] # "not" mask
...     print(key, positive, negative)
...
0 [3] [1 2 4 5 6 7 8]
1 [1 2 5] [3 4 6 7 8]
2 [4 8] [1 2 3 5 6 7]
3 [6] [1 2 3 4 5 7 8]
6 [7] [1 2 3 4 5 6 8]
Side note: If <anchor> and <positive> aren't allowed to represent the same entry in source_list, then we need each positive group to have at least two members in it. That is probably true of your actual data, but we'll just skip cases with a single positive for this demo.
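If you want to see up front which labels would be skipped, something like the following sketch might help. It uses np.unique with return_counts=True; the names singleton_labels and usable_labels are just illustrative, and it counts entries per label rather than unique source values, so it is only a rough check if your sources contain duplicates.

```python
import numpy as np

labels_array = np.array([1, 1, 0, 2, 1, 3, 6, 2])

# Count how many times each label occurs.
unique_labels, counts = np.unique(labels_array, return_counts=True)

# A label with a single entry cannot supply both an anchor
# and a distinct positive, so it gets skipped in this demo.
singleton_labels = unique_labels[counts < 2]
usable_labels = unique_labels[counts >= 2]

print(singleton_labels)  # [0 3 6]
print(usable_labels)     # [1 2]
```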
Now comes the itertools part. We want all unique permutations of 2 entries from the positive list, each combined with one entry from the negative list. Since the results must be unique, we can cut down on complexity slightly by removing duplicates from each positive and negative list.
>>> import itertools
>>> import numpy as np
>>> source_array = np.array([1,2,3,4,5,6,7,8])
>>> labels_array = np.array([1,1,0,2,1,3,6,2])
>>> for key in np.unique(labels_array):
...     mask = (labels_array == key)
...     positive = np.unique(source_array[mask]) # No duplicates
...     if len(positive) < 2: # Skip singleton positives
...         continue
...     negative = np.unique(source_array[~mask]) # No duplicates
...     print(key, positive, negative)
...     for ((a,b),c) in itertools.product(itertools.permutations(positive, 2),
...                                        negative):
...         print(a,b,c)
...     print()
Output:
1 [1 2 5] [3 4 6 7 8]
1 2 3
1 2 4
1 2 6
1 2 7
1 2 8
1 5 3
1 5 4
1 5 6
1 5 7
1 5 8
2 1 3
2 1 4
2 1 6
2 1 7
2 1 8
2 5 3
2 5 4
2 5 6
2 5 7
2 5 8
5 1 3
5 1 4
5 1 6
5 1 7
5 1 8
5 2 3
5 2 4
5 2 6
5 2 7
5 2 8
2 [4 8] [1 2 3 5 6 7]
4 8 1
4 8 2
4 8 3
4 8 5
4 8 6
4 8 7
8 4 1
8 4 2
8 4 3
8 4 5
8 4 6
8 4 7
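If you'd rather collect the triplets than print them, the same loop can be wrapped in a generator. This is only a sketch of one way to package it: iter_triplets is a made-up name, and it keeps the same assumptions as above (singleton positives are skipped, duplicates are removed from each group).

```python
import itertools
import numpy as np

def iter_triplets(sources, labels):
    """Yield (anchor, positive, negative) tuples, one label at a time.

    Hypothetical helper, not part of numpy or itertools.
    """
    sources = np.asarray(sources)
    labels = np.asarray(labels)
    for key in np.unique(labels):
        mask = (labels == key)
        positive = np.unique(sources[mask])    # No duplicates
        if len(positive) < 2:                  # Skip singleton positives
            continue
        negative = np.unique(sources[~mask])   # No duplicates
        # Ordered pairs from the positives, each combined with one negative.
        pairs = itertools.permutations(positive.tolist(), 2)
        for (a, b), c in itertools.product(pairs, negative.tolist()):
            yield (a, b, c)

triplets = list(iter_triplets([1, 2, 3, 4, 5, 6, 7, 8],
                              [1, 1, 0, 2, 1, 3, 6, 2]))
print(len(triplets))               # 42
print(triplets[0], triplets[-1])   # (1, 2, 3) (8, 4, 7)
```

The count matches what you'd expect per label: a group with p unique positives and n unique negatives gives p*(p-1)*n triplets, so 3*2*5 = 30 for label 1 plus 2*1*6 = 12 for label 2.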