This article walks through converting the BGE model to ONNX, and should serve as a practical reference for developers who need to do the same.
The BGE model
Download: https://hf-mirror.com/BAAI/bge-small-zh-v1.5
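If the model has not been downloaded yet, here is a minimal sketch using huggingface_hub (an assumption: the package is installed; the HF_ENDPOINT variable routes downloads through the mirror above and must be set before the import):

import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'  # use the hf-mirror endpoint
from huggingface_hub import snapshot_download
# download the repo into a local folder (the path here is illustrative)
snapshot_download(repo_id='BAAI/bge-small-zh-v1.5', local_dir='./bge_small')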
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

BGE_MODEL_PATH = '.../bge_small'
tokenizer = AutoTokenizer.from_pretrained(BGE_MODEL_PATH)
config = AutoConfig.from_pretrained(BGE_MODEL_PATH)
bge_model = AutoModel.from_pretrained(BGE_MODEL_PATH)
bge_model.eval()
# set the number of threads in the thread pool
torch.set_num_threads(4)
with torch.no_grad():
    try:
        model_output = bge_model(**tokenizer(['我喜欢吃牛肉面,你喜欢吃什么'],
                                             padding=True, truncation=True,
                                             return_tensors='pt'))
        # CLS pooling: take the first token's hidden state as the sentence embedding
        sentence_embeddings = model_output[0][:, 0]
    except RuntimeError as e:
        print(e)
sentence_embeddings
"""
tensor([[ 5.4689e-01, -4.4698e-01,  5.5987e-01,  5.0286e-01,  3.6766e-01,
          ...,
         -3.1166e-01, -4.5348e-01,  3.5287e-01]])

(output truncated; the embedding is 512-dimensional)
"""
Exporting the ONNX model
# !pip install onnx onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple/

# export the ONNX model
import onnxruntime
from itertools import chain
from transformers.onnx.features import FeaturesManager
onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification'](config)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')
output_onnx_path = "bert.onnx"
model = bge_model
torch.onnx.export(
    model,
    (dummy_inputs,),
    f=output_onnx_path,
    input_names=list(onnx_config.inputs.keys()),
    output_names=list(onnx_config.outputs.keys()),
    dynamic_axes={name: axes for name, axes in chain(onnx_config.inputs.items(),
                                                     onnx_config.outputs.items())},
    do_constant_folding=True,
    # note: use_external_data_format and enable_onnx_checker were removed in
    # newer PyTorch releases; drop these two arguments if export() rejects them
    use_external_data_format=onnx_config.use_external_data_format(model.num_parameters()),
    enable_onnx_checker=True,
    opset_version=onnx_config.default_onnx_opset,
)
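Before loading the file with onnxruntime, a quick structural check can catch a broken export early (a sketch, assuming the onnx package from the pip install above). Note that because AutoModel (no classification head) was exported with the sequence-classification config, the output named 'logits' is actually the model's last hidden state, which is why the CLS slice below recovers the sentence embedding:

import onnx
onnx_model = onnx.load(output_onnx_path)
onnx.checker.check_model(onnx_model)  # raises if the graph is malformed
print([i.name for i in onnx_model.graph.input])   # expect input_ids, attention_mask, token_type_ids
print([o.name for o in onnx_model.graph.output])  # expect 'logits' for this config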
Loading and running the ONNX model
from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession

output_onnx_path = "bert.onnx"
options = SessionOptions() # initialize session options
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# set the number of threads
options.intra_op_num_threads = 4

# pass the path of the ONNX model saved in the previous section
session = InferenceSession(output_onnx_path, sess_options=options,
                           providers=["CPUExecutionProvider"])

# disable session.run()'s fallback mechanism; it prevents a reset of the execution provider
session.disable_fallback()
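If there is any doubt about which names the session expects, they can be listed directly from the session itself:

# these should match the input_names/output_names used at export time
print([i.name for i in session.get_inputs()])
print([o.name for o in session.get_outputs()])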
text = ['我喜欢吃牛肉面,你喜欢吃什么']
inputs = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
inputs_detach = {k: v.detach().cpu().numpy() for k, v in inputs.items()}

# run the ONNX model
# 'logits' here must match the output_names passed to torch.onnx.export above
output = session.run(output_names=['logits'], input_feed=inputs_detach)
embeddings = output[0][:,0]
embeddings
"""
array([[ 0.5468873 , -0.44697893,  0.5598697 , ..., -0.31165794,
        -0.4534812 ,  0.35287267]], dtype=float32)

(values truncated; this is the CLS row of the raw session output)
"""
References
https://blog.csdn.net/weixin_44826203/article/details/127750113
That concludes this walkthrough of converting the BGE model to ONNX; I hope it proves helpful.