Another option is to use the constructor `csr_matrix((data, (row_ind, col_ind)), [shape=(M, N)])` from `scipy.sparse.csr_matrix`, where `data`, `row_ind` and `col_ind` satisfy the relationship `a[row_ind[k], col_ind[k]] = data[k]`.
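A quick sketch of what that relationship means in practice, using three made-up triplets. Each `(row_ind[k], col_ind[k], data[k])` triplet becomes one stored entry. The last lines show a detail the co-occurrence approach relies on: if the same `(row, col)` pair appears more than once, `csr_matrix` sums the values.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Three nonzero entries described as (row, column, value) triplets (example data)
row_ind = np.array([0, 1, 2])
col_ind = np.array([2, 0, 1])
data = np.array([5, 7, 9])

a = csr_matrix((data, (row_ind, col_ind)), shape=(3, 3))

# Every triplet lands where a[row_ind[k], col_ind[k]] = data[k] says it should
for k in range(len(data)):
    assert a[row_ind[k], col_ind[k]] == data[k]

print(a.toarray())
# [[0 0 5]
#  [7 0 0]
#  [0 9 0]]

# Duplicate (row, col) pairs are summed when the matrix is built
dup = csr_matrix((np.ones(2), ([0, 0], [1, 1])), shape=(2, 2))
print(dup[0, 1])  # 2.0
```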
The trick is to generate `row_ind` and `col_ind` by iterating over the documents and creating a list of `(doc_id, word_id)` tuples. `data` is then simply a vector of ones of the same length.
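A minimal sketch of that step, assuming the documents have already been mapped to integer word ids (the ids below are illustrative). It uses a flat list comprehension, which produces the same pairs as the `itertools.chain` expression in the function further down.

```python
import numpy as np

# Documents already converted to word ids (hypothetical example data)
documents_as_ids = [[0, 1], [1, 2], [0, 1, 2, 3]]

# One (doc_id, word_id) tuple per word occurrence
pairs = [(doc_id, word_id) for doc_id, doc in enumerate(documents_as_ids) for word_id in doc]
row_ind, col_ind = zip(*pairs)

# Each occurrence contributes a single 1 to the docs-words matrix
data = np.ones(len(row_ind), dtype='uint32')

print(row_ind)  # (0, 0, 1, 1, 2, 2, 2, 2)
print(col_ind)  # (0, 1, 1, 2, 0, 1, 2, 3)
```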
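Multiplying the transpose of the docs-words matrix `D` by `D` itself (`D.T @ D`) gives the co-occurrence matrix: entry `(i, j)` is the sum over documents `d` of `D[d, i] * D[d, j]`, which for a 0/1 matrix is the number of documents that contain both word `i` and word `j`.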
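To make the counting argument concrete, the sketch below builds a small binary docs-words matrix (example data, 4-word vocabulary assumed) and checks `D.T @ D` against a brute-force count of documents containing each word pair. The diagonal ends up holding how many documents contain each word, which is why the function below zeroes it with `setdiag(0)`.

```python
import itertools

import numpy as np
from scipy.sparse import csr_matrix

# Binary docs-words matrix: 3 documents over a 4-word vocabulary (example data)
documents_as_ids = [[0, 1], [1, 2], [0, 1, 2, 3]]
n_docs, n_words = len(documents_as_ids), 4

row_ind = [d for d, doc in enumerate(documents_as_ids) for _ in doc]
col_ind = [w for doc in documents_as_ids for w in doc]
data = np.ones(len(row_ind), dtype='uint32')
D = csr_matrix((data, (row_ind, col_ind)), shape=(n_docs, n_words))

# (D.T @ D)[i, j] = sum over documents d of D[d, i] * D[d, j]
cooc = (D.T @ D).toarray()

# Brute-force check: count the documents that contain both word i and word j
expected = np.zeros((n_words, n_words), dtype='uint32')
for doc in documents_as_ids:
    for i, j in itertools.product(set(doc), repeat=2):
        expected[i, j] += 1

assert (cooc == expected).all()
print(cooc)
```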
This approach is also efficient in both run time and memory usage, because the sparse matrices store only their nonzero entries, so it should scale to large corpora.
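One rough way to check the memory side of this claim is to compare the bytes a CSR docs-words matrix actually stores against a dense array of the same shape. The corpus sizes below are arbitrary assumptions with random word ids, not measurements from a real corpus. Note that the saving applies to the docs-words matrix; the words × words product can be much less sparse when many words co-occur.

```python
import numpy as np
from scipy.sparse import csr_matrix

rng = np.random.default_rng(0)

# Hypothetical corpus: 10,000 documents, 50,000-word vocabulary, 100 word ids per document
n_docs, n_words, words_per_doc = 10_000, 50_000, 100
row_ind = np.repeat(np.arange(n_docs), words_per_doc)
col_ind = rng.integers(0, n_words, size=n_docs * words_per_doc)
data = np.ones(len(row_ind), dtype='uint32')

D = csr_matrix((data, (row_ind, col_ind)), shape=(n_docs, n_words))

# CSR keeps only the nonzero values plus two index arrays
sparse_bytes = D.data.nbytes + D.indices.nbytes + D.indptr.nbytes
# A dense uint32 array of the same shape needs 4 bytes for every cell
dense_bytes = n_docs * n_words * np.dtype('uint32').itemsize

print(f"sparse: {sparse_bytes / 1e6:.1f} MB, dense: {dense_bytes / 1e6:.1f} MB")
```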
```python
import itertools

import numpy as np
from scipy.sparse import csr_matrix


def create_co_occurences_matrix(allowed_words, documents):
    print(f"allowed_words:\n{allowed_words}")
    print(f"documents:\n{documents}")
    # Map each allowed word to a column index
    word_to_id = dict(zip(allowed_words, range(len(allowed_words))))
    # Convert each document to sorted word ids, skipping words outside allowed_words
    documents_as_ids = [np.sort([word_to_id[w] for w in doc if w in word_to_id]).astype('uint32') for doc in documents]
    # One (doc_id, word_id) pair per occurrence: rows are documents, columns are words
    row_ind, col_ind = zip(*itertools.chain(*[[(i, w) for w in doc] for i, doc in enumerate(documents_as_ids)]))
    data = np.ones(len(row_ind), dtype='uint32')  # use unsigned int for better memory utilization
    # Size the columns by the vocabulary so every id in word_to_id is a valid index
    docs_words_matrix = csr_matrix((data, (row_ind, col_ind)), shape=(len(documents_as_ids), len(allowed_words)))  # efficient arithmetic operations with CSR * CSR
    # Multiplying the transposed docs-words matrix by itself generates the co-occurrence matrix
    words_cooc_matrix = docs_words_matrix.T @ docs_words_matrix
    words_cooc_matrix.setdiag(0)  # drop self co-occurrence (per-word document counts)
    print(f"words_cooc_matrix:\n{words_cooc_matrix.todense()}")
    return words_cooc_matrix, word_to_id
```
Run example:
```python
allowed_words = ['A', 'B', 'C', 'D']
documents = [['A', 'B'], ['C', 'B', 'K'], ['A', 'B', 'C', 'D', 'Z']]
words_cooc_matrix, word_to_id = create_co_occurences_matrix(allowed_words, documents)
```
Output:
```
allowed_words:
['A', 'B', 'C', 'D']
documents:
[['A', 'B'], ['C', 'B', 'K'], ['A', 'B', 'C', 'D', 'Z']]
words_cooc_matrix:
[[0 2 1 1]
 [2 0 2 1]
 [1 2 0 1]
 [1 1 1 0]]
```
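The returned `word_to_id` dictionary is what lets you read the matrix by word rather than by index. The sketch below types in the matrix from the printed output of the run example instead of recomputing it, and uses `scipy.sparse.triu` to list each word pair once (the matrix is symmetric). The names `id_to_word` and `pairs` are illustrative, not part of the function above.

```python
import numpy as np
from scipy.sparse import csr_matrix, triu

# Co-occurrence matrix and vocabulary copied from the run example's printed output
words_cooc_matrix = csr_matrix(np.array([[0, 2, 1, 1],
                                         [2, 0, 2, 1],
                                         [1, 2, 0, 1],
                                         [1, 1, 1, 0]], dtype='uint32'))
word_to_id = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
id_to_word = {i: w for w, i in word_to_id.items()}

# Look up a single pair by word
print(words_cooc_matrix[word_to_id['A'], word_to_id['B']])  # 2

# The matrix is symmetric, so keep the strict upper triangle to list each pair once
upper = triu(words_cooc_matrix, k=1).tocoo()
pairs = {(id_to_word[int(i)], id_to_word[int(j)]): int(c)
         for i, j, c in zip(upper.row, upper.col, upper.data)}
print(pairs)
```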