実現したいこと
今こちらのサイトで、「⽂書の「あいまい検索」機能をつくる」を学習しております。
https://axross-recipe.com/recipes/110
その中で類似度の高い記事を取得する関数を作り、
タプルの0番目が記事のindexで、1番目がコサイン類似度の結果を出したいと考えています。
ご教授いただければと思います。
発生している問題・分からないこと
以下のエラーメッセージが表示されました。
エラーメッセージ
error
1ValueError Traceback (most recent call last) 2Cell In[288], line 3 3 1 search_words = ["ビジネスマン"] 4 2 target_vec = to_vec(search_words) 5----> 3 res = get_similaries(target_vec, vecs_with_idx, topn=10) 6 4 print(res) 7 8Cell In[287], line 3, in get_similaries(target_vec, vecs_with_idx, topn) 9 1 def get_similaries(target_vec, vecs_with_idx, topn=10): 10 2 sim_list = [(idx, cos_sim(target_vec, v)) for idx, v in vecs_with_idx] 11----> 3 result = sorted(sim_list, key=lambda t: t[1], reverse=True) 12 4 return result[:topn] 13 14ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() 15
該当のソースコード
pip install gensim from gensim.models.fasttext import FastText model = FastText(size=300) from tqdm.notebook import tqdm sentences = [tokenizer.tokenize(normalize(read_doc(p))) for p in tqdm(doc_paths)] model.build_vocab(sentences=sentences) model.train(sentences=sentences, total_examples=len(sentences), epochs=30) vector = model.wv["猫"] print("次元:", vector.shape) pprint.pprint(vector, compact=True) import numpy as np def to_vec(sentence): return np.mean([model.wv[w] for w in sentence], axis=0) doc_vec = to_vec(sentences[0]) pprint.pprint(doc_vec, compact=True) def cos_sim(v1, v2): return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)) v1 = [1,0] # → v2 = [-1,0] # ← (v1と逆向き) v3 = [0, 1] # ↑ (v1と直角) v4 = [0.8,0.2] # v1と向きが近い print("v1とv2の類似度:", cos_sim(v1,v2)) print("v1とv3の類似度:", cos_sim(v1,v3)) print("v1とv4の類似度:", cos_sim(v1,v4)) vecs_with_idx = [(idx, to_vec(s)) for idx, s in enumerate(sentences)] def get_similaries(target_vec, vecs_with_idx, topn=10): sim_list = [(idx, cos_sim(target_vec, v)) for idx, v in vecs_with_idx] result = sorted(sim_list, key=lambda t: t[1], reverse=True) return result[:topn] search_words = ["ビジネスマン"] target_vec = to_vec(search_words) res = get_similaries(target_vec, vecs_with_idx, topn=10) print(res)
試したこと・調べたこと
上記の詳細・結果
a.any() or a.all()を入れるということはわかりましたが、どこに記入してよいかわからず悩んでいます。
補足
特になし
0 コメント