本文主要是介绍【Faiss】基础索引类型(六),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
基础索引类型
数据准备
import numpy as np
# Synthetic data set: Gaussian vectors (mean `mu`, std `sigma`) stored as float32.
d = 512  # vector dimensionality
n_data = 2000
n_query = 10
mu = 3
sigma = 0.1

# Database vectors, drawn one row at a time under a fixed seed.
np.random.seed(0)
data = np.array(
    [np.random.normal(mu, sigma, d) for _ in range(n_data)]
).astype('float32')

# Query vectors, drawn the same way under a different seed.
np.random.seed(12)
query = np.array(
    [np.random.normal(mu, sigma, d) for _ in range(n_query)]
).astype('float32')
#导入faiss
import sys
sys.path.append('/home/maliqi/faiss/python/')
import faiss
1.精确搜索(Exact Search for L2)
一种暴力搜索方法,遍历数据库中的每一个向量与查询向量对比。
# Exact (brute-force) L2 search: each query is compared with every stored vector.
index = faiss.IndexFlatL2(d)
# index = faiss.index_factory(d, "Flat")  # alternative way to define the same index
index.add(data)
# Request the 10 best matches per query and print the returned distances.
dis, ind = index.search(query, 10)
print(dis)
[[8.61838 8.782156 8.782816 8.832029 8.837633 8.848496 8.897978 8.916636 8.919006 8.9374]
 [9.033303 9.038907 9.091705 9.15584 9.164591 9.200112 9.201884 9.220335 9.279477 9.312859]
 [8.063818 8.211029 8.306456 8.373352 8.459253 8.459892 8.498557 8.546464 8.555408 8.621426]
 [8.193894 8.211956 8.34701 8.446963 8.45299 8.45486 8.473572 8.50477 8.513636 8.530684]
 [8.369624 8.549444 8.704066 8.736764 8.760082 8.777319 8.831345 8.835486 8.858271 8.860058]
 [8.299072 8.432398 8.434382 8.457374 8.539217 8.562359 8.579033 8.618736 8.630861 8.643393]
 [8.615004 8.615164 8.72604 8.730943 8.762621 8.796932 8.797068 8.797365 8.813985 8.834726]
 [8.377227 8.522776 8.711159 8.724562 8.745737 8.763846 8.768602 8.7727995 8.786856 8.828224]
 [8.342917 8.488056 8.655106 8.662771 8.701336 8.741287 8.743608 8.770507 8.786264 8.849051]
 [8.522164 8.575703 8.68462 8.767247 8.782909 8.850494 8.883733 8.90369 8.909393 8.91768]]
2.精确搜索(Exact Search for Inner Product)
当数据库向量和查询向量都经过L2标准化(即长度为1)时,返回的distance就是余弦相似度;本例中的数据没有做标准化,因此下面输出的是原始内积。
# Exact inner-product search. `data` and `query` are used as generated (this
# script applies no normalization step), so the printed scores are raw inner
# products rather than cosine similarities.
index = faiss.IndexFlatIP(d)
index.add(data)
# Request the 10 best matches per query and print their scores.
dis, ind = index.search(query, 10)
print(dis)
[[4621.749 4621.5464 4619.745 4619.381 4619.177 4618.0615 4617.169 4617.0566 4617.0483 4616.631]
 [4637.3975 4637.288 4635.368 4635.2446 4634.881 4633.608 4633.0215 4632.7637 4632.56 4632.373]
 [4621.756 4621.4697 4619.7485 4619.5615 4619.424 4618.0186 4616.9927 4616.962 4616.901 4616.735]
 [4623.6074 4623.5596 4621.3965 4621.158 4620.906 4619.838 4618.9756 4618.9126 4618.7695 4618.478]
 [4625.553 4625.0645 4623.461 4623.196 4622.957 4621.337 4620.7373 4620.717 4620.5635 4620.2485]
 [4628.489 4628.449 4626.491 4626.487 4625.6406 4624.6143 4624.29 4624. 4623.7524 4623.618]
 [4637.7466 4637.338 4635.3047 4635.125 4634.748 4633.0137 4632.864 4632.58 4632.3027 4632.2324]
 [4630.472 4630.333 4628.264 4627.9375 4627.738 4626.8965 4625.814 4625.7227 4625.4443 4625.091]
 [4635.7715 4635.489 4633.6904 4633.568 4632.658 4631.463 4631.4307 4631.101 4630.99 4630.3066]
 [4625.6753 4625.558 4623.454 4623.3926 4623.324 4622.2827 4621.7783 4621.1157 4620.905 4620.854]]
3.HNSW图搜索(Hierarchical Navigable Small World graph exploration)
返回近似结果。
# HNSW graph index; the returned neighbours are approximate.
# NOTE(review): the second argument (16) is presumably the HNSW graph parameter
# M (links per node) - confirm against the Faiss IndexHNSWFlat documentation.
index = faiss.IndexHNSWFlat(d,16)
index.add(data)
# Request the 10 best matches per query and print the returned distances.
dis, ind = index.search(query, 10)
print(dis)
[[8.61838 8.832029 8.848496 8.897978 8.916636 8.9374 8.9597 8.962785 8.984709 8.998907]
 [9.038907 9.164591 9.200112 9.201884 9.220335 9.312859 9.34434 9.344851 9.416974 9.421429]
 [8.306456 8.373352 8.459253 8.546464 8.631898 8.63715 8.63917 8.713682 8.735945 8.7704735]
 [8.193894 8.211956 8.34701 8.45486 8.473572 8.50477 8.513636 8.530684 8.545482 8.617173]
 [8.369624 8.760082 8.831345 8.858271 8.860058 8.862642 8.936951 8.996922 8.998444 9.022133]
 [8.299072 8.432398 8.434382 8.539217 8.562359 8.698317 8.753672 8.768751 8.779131 8.780444]
 [8.615004 8.615164 8.730943 8.797365 8.861536 8.885755 8.911812 8.922768 8.942963 8.980488]
 [8.377227 8.522776 8.711159 8.724562 8.745737 8.768602 8.7727995 8.786856 8.828224 8.879469]
 [8.342917 8.488056 8.662771 8.741287 8.743608 8.770507 8.857255 8.893716 8.932134 8.933593]
 [8.575703 8.68462 8.850494 8.883733 8.90369 8.909393 8.91768 8.936615 8.961668 8.977329]]
4.倒排表搜索(Inverted file with exact post-verification)
快速入门部分介绍过。
# Inverted-file index: a flat L2 index serves as the coarse quantizer and the
# data set is partitioned into `nlist` cells.
nlist = 50
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)
# Train first so the cell centres exist before vectors are assigned on add().
index.train(data)
index.add(data)
# nprobe is never set in this snippet, so the search runs with the index's
# default value.
dis, ind = index.search(query, 10)
print(dis)
[[8.837633 9.122337 9.217627 9.362019 9.39345 9.396795 9.401556 9.446939 9.52043 9.5279255]
 [9.436286 9.636714 9.707813 9.714355 9.734249 9.809814 9.87722 9.960412 9.978079 9.982276]
 [8.621426 8.658703 8.842339 8.862192 8.891519 8.937078 8.972767 8.98658 9.007745 9.088661]
 [8.211956 8.735372 8.747662 8.800873 8.917062 9.1208725 9.178852 9.215968 9.2192 9.265095]
 [8.858271 8.998444 9.041813 9.0883045 9.159481 9.169218 9.187948 9.203735 9.204121 9.256811]
 [8.434382 8.539217 8.630861 8.753672 8.768751 8.794859 8.815165 8.817884 8.8404 8.848925]
 [8.861536 8.878873 8.942963 8.944212 8.9446945 8.95914 8.980488 9.051479 9.059914 9.081419]
 [9.15522 9.423113 9.432117 9.465836 9.529045 9.554071 9.556268 9.638275 9.656209 9.69151]
 [8.743608 8.902418 9.065649 9.201052 9.223066 9.223073 9.247414 9.269661 9.288244 9.291237]
 [8.936615 9.077 9.152468 9.1537075 9.313195 9.314999 9.373196 9.400535 9.434517 9.445862]]
5.LSH(Locality-Sensitive Hashing (binary flat index))
# LSH index (binary flat index); the code length is set to twice the dimension.
nbits = 2 * d
index = faiss.IndexLSH(d, nbits)
index.train(data)
index.add(data)
# NOTE(review): the printed values are whole numbers - presumably Hamming
# distances between the binary codes; confirm against the Faiss IndexLSH docs.
dis, ind = index.search(query, 10)
print(dis)
[[ 8. 10. 10. 10. 10. 10. 10. 11. 11. 11.]
 [ 7. 8. 9. 9. 9. 10. 10. 10. 10. 10.]
 [ 7. 8. 8. 9. 9. 9. 9. 9. 9. 9.]
 [ 9. 9. 10. 11. 12. 12. 12. 12. 12. 12.]
 [ 6. 6. 6. 7. 7. 8. 8. 8. 8. 8.]
 [ 8. 8. 8. 9. 9. 9. 9. 9. 10. 10.]
 [ 6. 7. 8. 8. 9. 9. 9. 9. 9. 9.]
 [ 9. 9. 9. 9. 9. 9. 9. 9. 9. 10.]
 [ 7. 8. 8. 8. 8. 8. 8. 9. 9. 9.]
 [ 9. 9. 9. 10. 10. 10. 10. 10. 10. 10.]]
6.SQ量化(Scalar quantizer (SQ) in flat mode)
# Scalar-quantizer index in flat mode.
# The second constructor argument is a ScalarQuantizer.QuantizerType enum value,
# not a number of bits. The literal 4 used previously is the value of QT_fp16,
# so the constant is spelled out to make the chosen encoding explicit.
index = faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_fp16)
# The quantizer is trained on the data before any vectors are added.
index.train(data)
index.add(data)
# Request the 10 best matches per query and print the returned distances.
dis, ind = index.search(query, 10)
print(dis)
[[8.623227 8.777792 8.785317 8.828824 8.83549 8.845292 8.896896 8.914818 8.922382 8.934983]
 [9.028506 9.037546 9.099248 9.1526165 9.16542 9.19639 9.200499 9.224975 9.274046 9.3053875]
 [8.064029 8.21301 8.310526 8.376435 8.457833 8.462002 8.501087 8.550647 8.556992 8.624525]
 [8.19665 8.210531 8.346436 8.444769 8.452809 8.454114 8.4745245 8.496618 8.510042 8.525612]
 [8.370452 8.547959 8.704323 8.733619 8.763926 8.776738 8.829511 8.835644 8.857149 8.859046]
 [8.29591 8.432422 8.435944 8.454732 8.542395 8.565367 8.579683 8.621871 8.632034 8.644775]
 [8.609016 8.612934 8.72663 8.734133 8.758857 8.797326 8.797966 8.798654 8.815295 8.8382225]
 [8.378947 8.521084 8.711153 8.726161 8.748383 8.759655 8.768218 8.769182 8.792372 8.834644]
 [8.340463 8.48951 8.659344 8.664954 8.702756 8.741513 8.741941 8.768993 8.781276 8.852154]
 [8.520282 8.574987 8.683459 8.769213 8.7820425 8.85128 8.881118 8.906741 8.907756 8.924014]]
7.PQ量化(Product quantizer (PQ) in flat mode)
# Product-quantizer index in flat mode, built from M sub-quantizers of nbits
# bits each.
M = 8 # must be a divisor of d
nbits = 6 # NOTE(review): this comment used to say only 8, 12 or 16 are allowed, yet 6 is used here - confirm the valid range for your Faiss version
index = faiss.IndexPQ(d, M, nbits)
# The quantizer is trained on the data before any vectors are added.
index.train(data)
index.add(data)
# Request the 10 best matches per query and print the returned distances.
dis, ind = index.search(query, 10)
print(dis)
[[5.3184814 5.33667 5.3638916 5.366333 5.3704834 5.4000244 5.404663 5.415283 5.425659 5.427246]
 [5.6835938 5.686035 5.687134 5.7489014 5.76062 5.7731934 5.7766113 5.7875977 5.798828 5.7991943]
 [4.902588 5.0057373 5.0323486 5.036255 5.045044 5.048828 5.0498047 5.0499268 5.072998 5.0737305]
 [4.844116 4.850586 4.868042 4.8946533 4.8997803 4.8999023 4.902954 4.909546 4.9210205 4.921875]
 [5.279419 5.333252 5.3344727 5.3431396 5.35083 5.357422 5.366211 5.3862305 5.38855 5.3936768]
 [5.019409 5.048706 5.0942383 5.1052246 5.116455 5.157593 5.159424 5.168457 5.171875 5.194092]
 [5.0563965 5.0909424 5.1367188 5.1534424 5.1724854 5.199951 5.20105 5.2144775 5.214966 5.23938]
 [5.16333 5.173706 5.2418213 5.265259 5.265869 5.274414 5.291382 5.307495 5.309204 5.310425]
 [5.1501465 5.2508545 5.291992 5.3186035 5.3205566 5.328369 5.336548 5.3479004 5.35376 5.360962]
 [5.2751465 5.2772217 5.279663 5.3304443 5.350708 5.3571777 5.3669434 5.373047 5.373413 5.382324]]
8.倒排表乘积量化(IVFADC (coarse quantizer+PQ on residuals))
# IVFADC: a flat L2 coarse quantizer with `nlist` cells combined with a product
# quantizer (M sub-quantizers, nbits bits each).
M = 8
nbits = 4
nlist = 50
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, M, nbits)
# Training comes before add(), as in the other trained indexes in this file.
index.train(data)
index.add(data)
# nprobe is never set in this snippet, so the search runs with the index's
# default value.
dis, ind = index.search(query, 10)
print(dis)
[[5.1985765 5.209732 5.233874 5.237282 5.2553835 5.262968 5.270462 5.2895284 5.2908745 5.302353]
 [5.5696826 5.5942397 5.611737 5.6186624 5.619787 5.643144 5.646076 5.676093 5.682111 5.6982036]
 [4.7446747 4.824335 4.834736 4.844829 4.850663 4.853364 4.867393 4.873641 4.8785725 4.88787]
 [4.783175 4.797909 4.8491716 4.85687 4.857151 4.8586845 4.860058 4.866444 4.868099 4.885188]
 [5.1260395 5.134188 5.1386065 5.141901 5.1756086 5.192538 5.1938267 5.1975694 5.199704 5.2012296]
 [4.882325 4.900981 4.9040375 4.911916 4.916094 4.923492 4.928433 4.928472 4.937878 4.95728]
 [4.9729834 4.976016 4.984484 5.0074816 5.0200887 5.0217285 5.029479 5.029899 5.0346465 5.0349855]
 [5.1357193 5.147153 5.1525207 5.189519 5.217377 5.220489 5.2341766 5.239973 5.2411985 5.253551]
 [5.0623484 5.087064 5.1075807 5.109309 5.110051 5.1330123 5.1387715 5.1431603 5.151037 5.1516275]
 [5.12455 5.163775 5.1762547 5.185327 5.190364 5.19723 5.2099175 5.2115583 5.214532 5.2182474]]
cell-probe方法
为了加速检索过程,经常采用划分子空间(如k-means聚类)的方法,代价是无法保证最后返回的结果完全正确。先划分子空间,再只在部分子空间中搜索的方法,就是cell-probe方法。
具体流程为:
1)数据集空间被划分为n个部分,在k-means中,表现为n个类;
2)每个类中的向量保存在一个倒排表中,共有n个倒排表;
3)查询时,选中nprobe个倒排表;
4)将这几个倒排表中的向量与查询向量作对比。
在这种方法中,只需要比对数据库中的一部分向量,大约是全部数据的nprobe/n;之所以只是“大约”,是因为各个倒排表的长度并不一致(每个类中的向量个数不一定相等)。
cell-probe粗量化
在一些索引类型中,需要一个Flat index作为粗量化器,如IndexIVFFlat,在训练的时候会将类中心保存在Flat index中,在add和search阶段,会首先判定将其落入哪个类空间。在search阶段,nprobe参数需要调整以权衡检索精度与检索速度。
实验表明,对高维数据,需要维持比较高的nprobe数值才能保证精度。
与LSH的优劣
LSH也是一种cell-probe方法,但与上述基于聚类划分的方法相比,LSH有以下两点不足:
1)LSH需要大量的哈希函数才能达到可接受的效果,会带来额外的内存开销;
2)哈希函数与输入数据无关,不会根据数据分布进行调整,实际效果往往不是最优。
这篇关于【Faiss】基础索引类型(六)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!