rasa train nlu详解:1.2-_train_graph()函数

2023-11-12 01:36

本文主要是介绍rasa train nlu详解:1.2-_train_graph()函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  本文使用《使用ResponseSelector实现校园招聘FAQ机器人》中的例子,主要详解介绍_train_graph()函数中变量的具体值。

一.rasa/model_training.py/_train_graph()函数
  _train_graph()函数实现,如下所示:

def _train_graph(file_importer: TrainingDataImporter,training_type: TrainingType,output_path: Text,fixed_model_name: Text,model_to_finetune: Optional[Union[Text, Path]] = None,force_full_training: bool = False,dry_run: bool = False,**kwargs: Any,
) -> TrainingResult:if model_to_finetune:  # 如果有模型微调model_to_finetune = rasa.model.get_model_for_finetuning(model_to_finetune)  # 获取模型微调if not model_to_finetune:  # 如果没有模型微调rasa.shared.utils.cli.print_error_and_exit(  # 打印错误并退出f"No model for finetuning found. Please make sure to either "   # 没有找到微调模型。请确保f"specify a path to a previous model or to have a finetunable " # 要么指定一个以前模型的路径,要么有一个可微调的f"model within the directory '{output_path}'."                  # 在目录'{output_path}'中的模型。)rasa.shared.utils.common.mark_as_experimental_feature(  # 标记为实验性功能"Incremental Training feature"  # 增量训练功能)is_finetuning = model_to_finetune is not None  # 如果有模型微调config = file_importer.get_config()  # 获取配置recipe = Recipe.recipe_for_name(config.get("recipe"))  # 获取配方config, _missing_keys, _configured_keys = recipe.auto_configure(  # 自动配置file_importer.get_config_file_for_auto_config(),  # 获取自动配置的配置文件config,  # 配置training_type,  # 训练类型)model_configuration = recipe.graph_config_for_recipe(  # 配方的graph配置config,  # 配置kwargs,  # 关键字参数training_type=training_type,  # 训练类型is_finetuning=is_finetuning,  # 是否微调)rasa.engine.validation.validate(model_configuration)  # 验证tempdir_name = rasa.utils.common.get_temp_dir_name()  # 获取临时目录名称# Use `TempDirectoryPath` instead of `tempfile.TemporaryDirectory` as this leads to errors on Windows when the context manager tries to delete an already deleted temporary directory (e.g. https://bugs.python.org/issue29982)# 翻译:使用TempDirectoryPath而不是tempfile.TemporaryDirectory,因为当上下文管理器尝试删除已删除的临时目录时,这会导致Windows上的错误(例如https://bugs.python.org/issue29982)with rasa.utils.common.TempDirectoryPath(tempdir_name) as temp_model_dir:  # 临时模型目录model_storage = _create_model_storage(  # 创建模型存储is_finetuning, model_to_finetune, Path(temp_model_dir)  # 是否微调,模型微调,临时模型目录)cache = LocalTrainingCache()  # 本地训练缓存trainer = GraphTrainer(model_storage, cache, DaskGraphRunner)  # Graph训练器if dry_run:  # dry运行fingerprint_status = trainer.fingerprint(                        # fingerprint状态model_configuration.train_schema, file_importer              # 模型配置的训练模式,文件导入器)return _dry_run_result(fingerprint_status, force_full_training)  # 返回dry运行结果model_name = _determine_model_name(fixed_model_name, training_type)  # 确定模型名称full_model_path = Path(output_path, model_name)                # 完整的模型路径with telemetry.track_model_training(                    # 跟踪模型训练file_importer, model_type=training_type.model_type  # 文件导入器,模型类型):trainer.train(                               # 训练model_configuration,                     # 模型配置file_importer,                           # 文件导入器full_model_path,                         # 完整的模型路径force_retraining=force_full_training,    # 强制重新训练is_finetuning=is_finetuning,             # 是否微调)rasa.shared.utils.cli.print_success(         # 打印成功f"Your Rasa model is trained and saved at '{full_model_path}'."  # Rasa模型已经训练并保存在'{full_model_path}'。)return TrainingResult(str(full_model_path), 0)   # 训练结果

1.传递来的形参数据

2._train_graph()函数组成
  该函数主要由3个方法组成,如下所示:

  • model_configuration = recipe.graph_config_for_recipe(*)
  • trainer = GraphTrainer(model_storage, cache, DaskGraphRunner)
  • trainer.train(model_configuration, file_importer, full_model_path, force_retraining, is_finetuning)

二._train_graph()函数中的方法
1.file_importer.get_config()
  将config.yml文件转化为dict类型,如下所示:

2.Recipe.recipe_for_name(config.get(“recipe”))

(1)ENTITY_EXTRACTOR = ComponentType.ENTITY_EXTRACTOR
实体抽取器。
(2)INTENT_CLASSIFIER = ComponentType.INTENT_CLASSIFIER
意图分类器。
(3)MESSAGE_FEATURIZER = ComponentType.MESSAGE_FEATURIZER
消息特征化。
(4)MESSAGE_TOKENIZER = ComponentType.MESSAGE_TOKENIZER
消息Tokenizer。
(5)MODEL_LOADER = ComponentType.MODEL_LOADER
模型加载器。
(6)POLICY_WITHOUT_END_TO_END_SUPPORT = ComponentType.POLICY_WITHOUT_END_TO_END_SUPPORT
非端到端策略支持。
(7)POLICY_WITH_END_TO_END_SUPPORT = ComponentType.POLICY_WITH_END_TO_END_SUPPORT
端到端策略支持。

3.model_configuration = recipe.graph_config_for_recipe(*)
  model_configuration.train_schema和model_configuration.predict_schema的数据类型都是GraphSchema类对象,分别表示在训练和预测时所需要的SchemaNode,以及SchemaNode在GraphSchema中的依赖关系。

(1)model_configuration.train_schema

  • schema_validator:rasa.graph_components.validators.default_recipe_validator.DefaultV1RecipeValidator类中的validate方法
  • finetuning_validator:rasa.graph_components.validators.finetuning_validator.FinetuningValidator类中的validate方法
  • nlu_training_data_provider:rasa.graph_components.providers.nlu_training_data_provider.NLUTrainingDataProvider类中的provide方法
  • train_JiebaTokenizer0:rasa.nlu.tokenizers.jieba_tokenizer.JiebaTokenizer类中的train方法
  • run_JiebaTokenizer0:rasa.nlu.tokenizers.jieba_tokenizer.JiebaTokenizer类中的process_training_data方法
  • run_LanguageModelFeaturizer1:rasa.nlu.featurizers.dense_featurizer.lm_featurizer.LanguageModelFeaturizer类中的process_training_data方法
  • train_DIETClassifier2:rasa.nlu.classifiers.diet_classifier.DIETClassifier类中的train方法
  • train_ResponseSelector3:rasa.nlu.selectors.response_selector.ResponseSelector类中的train方法

说明:ResponseSelector类继承自DIETClassifier类。

(2)model_configuration.predict_schema

  • nlu_message_converter:rasa.graph_components.converters.nlu_message_converter.NLUMessageConverter类中的convert_user_message方法
  • run_JiebaTokenizer0:rasa.nlu.tokenizers.jieba_tokenizer.JiebaTokenizer类中的process方法
  • run_LanguageModelFeaturizer1:rasa.nlu.featurizers.dense_featurizer.lm_featurizer.LanguageModelFeaturizer类中的process方法
  • run_DIETClassifier2:rasa.nlu.classifiers.diet_classifier.DIETClassifier类中的process方法
  • run_ResponseSelector3:rasa.nlu.selectors.response_selector.ResponseSelector类中的process方法
  • run_RegexMessageHandler:rasa.nlu.classifiers.regex_message_handler.RegexMessageHandler类中的process方法

4.tempdir_name
  ‘C:\Users\ADMINI~1\AppData\Local\Temp\tmpg0v179ea’

5.trainer = GraphTrainer(*)和trainer.train(*)
  这里执行的代码是rasa/engine/training/graph_trainer.py中GraphTrainer类的train()方法,实现功能为训练和打包模型并返回预测graph运行程序。

6.Rasa中GraphComponent的子类


参考文献:
[1]https://github.com/RasaHQ/rasa
[2]rasa 3.2.10 NLU模块的训练:https://zhuanlan.zhihu.com/p/574935615
[3]rasa.engine.graph:https://rasa.com/docs/rasa/next/reference/rasa/engine/graph/

这篇关于rasa train nlu详解:1.2-_train_graph()函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/394195

相关文章

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

hdu1171(母函数或多重背包)

题意:把物品分成两份,使得价值最接近 可以用背包,或者是母函数来解,母函数(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v) 其中指数为价值,每一项的数目为(该物品数+1)个 代码如下: #include<iostream>#include<algorithm>

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof

usaco 1.2 Palindromic Squares(进制转化)

考察进制转化 注意一些细节就可以了 直接上代码: /*ID: who jayLANG: C++TASK: palsquare*/#include<stdio.h>int x[20],xlen,y[20],ylen,B;void change(int n){int m;m=n;xlen=0;while(m){x[++xlen]=m%B;m/=B;}m=n*n;ylen=0;whi

usaco 1.2 Name That Number(数字字母转化)

巧妙的利用code[b[0]-'A'] 将字符ABC...Z转换为数字 需要注意的是重新开一个数组 c [ ] 存储字符串 应人为的在末尾附上 ‘ \ 0 ’ 详见代码: /*ID: who jayLANG: C++TASK: namenum*/#include<stdio.h>#include<string.h>int main(){FILE *fin = fopen (

usaco 1.2 Milking Cows(类hash表)

第一种思路被卡了时间 到第二种思路的时候就觉得第一种思路太坑爹了 代码又长又臭还超时!! 第一种思路:我不知道为什么最后一组数据会被卡 超时超了0.2s左右 大概想法是 快排加一个遍历 先将开始时间按升序排好 然后开始遍历比较 1 若 下一个开始beg[i] 小于 tem_end 则说明本组数据与上组数据是在连续的一个区间 取max( ed[i],tem_end ) 2 反之 这个

usaco 1.2 Transformations(模拟)

我的做法就是一个一个情况枚举出来 注意计算公式: ( 变换后的矩阵记为C) 顺时针旋转90°:C[i] [j]=A[n-j-1] [i] (旋转180°和270° 可以多转几个九十度来推) 对称:C[i] [n-j-1]=A[i] [j] 代码有点长 。。。 /*ID: who jayLANG: C++TASK: transform*/#include<

6.1.数据结构-c/c++堆详解下篇(堆排序,TopK问题)

上篇:6.1.数据结构-c/c++模拟实现堆上篇(向下,上调整算法,建堆,增删数据)-CSDN博客 本章重点 1.使用堆来完成堆排序 2.使用堆解决TopK问题 目录 一.堆排序 1.1 思路 1.2 代码 1.3 简单测试 二.TopK问题 2.1 思路(求最小): 2.2 C语言代码(手写堆) 2.3 C++代码(使用优先级队列 priority_queue)

K8S(Kubernetes)开源的容器编排平台安装步骤详解

K8S(Kubernetes)是一个开源的容器编排平台,用于自动化部署、扩展和管理容器化应用程序。以下是K8S容器编排平台的安装步骤、使用方式及特点的概述: 安装步骤: 安装Docker:K8S需要基于Docker来运行容器化应用程序。首先要在所有节点上安装Docker引擎。 安装Kubernetes Master:在集群中选择一台主机作为Master节点,安装K8S的控制平面组件,如AP

C++操作符重载实例(独立函数)

C++操作符重载实例,我们把坐标值CVector的加法进行重载,计算c3=c1+c2时,也就是计算x3=x1+x2,y3=y1+y2,今天我们以独立函数的方式重载操作符+(加号),以下是C++代码: c1802.cpp源代码: D:\YcjWork\CppTour>vim c1802.cpp #include <iostream>using namespace std;/*** 以独立函数