babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现

2024-04-05 21:04

本文主要是介绍babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在进入到主程序前,我们还需要看一个Embedding的实现代码,这里的功能主要是为了计算代码之间的相关性。
embedding可以文本中的词语转化为低维实数向量的表示,来计算两段文字间的几何距离来判断词语的含义是否相近。

1. 源码阅读-初始化和计算代码库的嵌入值

这段代码主要是设定了初始化变量,包括使用的embedding的模型,以及tokenizer(分词器),分词器按照\n,作为分词符号和分词长度。

class Embeddings:def __init__(self, workspace_path: str):self.workspace_path = workspace_pathopenai.api_key = os.getenv("OPENAI_API_KEY", "")self.DOC_EMBEDDINGS_MODEL = f"text-embedding-ada-002"self.QUERY_EMBEDDINGS_MODEL = f"text-embedding-ada-002"self.SEPARATOR = "\n* "self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")self.separator_len = len(self.tokenizer.tokenize(self.SEPARATOR))

下面的代码用于计算整个代码库的embedding值,用于查找相关代码,实现了以下步骤

  • 删除playground_data 代码空间下的所有文件,不会有旧的数据重新计算
  • 将代码文件转换为特定格式,放入repository_info.csv文件中
  • 计算repository_info.csv中内容的嵌入值,放入到doc_embeddings.csv
def compute_repository_embeddings(self):try:playground_data_path = os.path.join(self.workspace_path, 'playground_data')# Delete the contents of the playground_data directory but not the directory itself# This is to ensure that we don't have any old data lying aroundfor filename in os.listdir(playground_data_path):file_path = os.path.join(playground_data_path, filename)try:if os.path.isfile(file_path) or os.path.islink(file_path):os.unlink(file_path)elif os.path.isdir(file_path):shutil.rmtree(file_path)except Exception as e:print(f"Failed to delete {file_path}. Reason: {str(e)}")except Exception as e:print(f"Error: {str(e)}")# extract and save info to csvinfo = self.extract_info(REPOSITORY_PATH)self.save_info_to_csv(info)df = pd.read_csv(os.path.join(self.workspace_path, 'playground_data\\repository_info.csv'))df = df.set_index(["filePath", "lineCoverage"])self.df = dfcontext_embeddings = self.compute_doc_embeddings(df)self.save_doc_embeddings_to_csv(context_embeddings, df, os.path.join(self.workspace_path, 'playground_data\\doc_embeddings.csv'))try:self.document_embeddings = self.load_embeddings(os.path.join(self.workspace_path, 'playground_data\\doc_embeddings.csv'))except:pass

下面是使用到的extract_info函数、save_info_to_csv函数、compute_doc_embeddings函数、save_doc_embeddings_to_csv函数,load_embeddings 函数。

1.1 extract_info 提取代码文件信息

这个函数的功能是从文件中获取信息,转化为特定的形式一个列表,包含三个信息

  • filePath 文件路径
  • lineCoverage 为一个元组,包含两个值 第一行位置和最后一行的位置
  • chunkContent 代码的内容
# Extract information from files in the repository in chunks
# Return a list of [filePath, lineCoverage, chunkContent]
def extract_info(self, REPOSITORY_PATH):# Initialize an empty list to store the informationinfo = []LINES_PER_CHUNK = 60# Iterate through the files in the repositoryfor root, dirs, files in os.walk(REPOSITORY_PATH):for file in files:file_path = os.path.join(root, file)# Read the contents of the filewith open(file_path, "r", encoding="utf-8") as f:try:contents = f.read()except:continue# Split the contents into lineslines = contents.split("\n")# Ignore empty lineslines = [line for line in lines if line.strip()]# Split the lines into chunks of LINES_PER_CHUNK lineschunks = [lines[i:i+LINES_PER_CHUNK]for i in range(0, len(lines), LINES_PER_CHUNK)]# Iterate through the chunksfor i, chunk in enumerate(chunks):# Join the lines in the chunk back into a single stringchunk = "\n".join(chunk)# Get the first and last line numbersfirst_line = i * LINES_PER_CHUNK + 1last_line = first_line + len(chunk.split("\n")) - 1line_coverage = (first_line, last_line)# Add the file path, line coverage, and content to the listinfo.append((os.path.join(root, file), line_coverage, chunk))# Return the list of informationreturn info

1.2 save_info_to_csv保存提取出的信息

这个函数的功能是将代码信息存放到csv文件中,使用pandas库

def save_info_to_csv(self, info):# Open a CSV file for writingos.makedirs(os.path.join(self.workspace_path, "playground_data"), exist_ok=True)with open(os.path.join(self.workspace_path, 'playground_data\\repository_info.csv'), "w", newline="") as csvfile:# Create a CSV writerwriter = csv.writer(csvfile)# Write the header rowwriter.writerow(["filePath", "lineCoverage", "content"])# Iterate through the infofor file_path, line_coverage, content in info:# Write a row for each chunk of datawriter.writerow([file_path, line_coverage, content])

1.3 compute_doc_embeddings计算文档的嵌入值信息

计算每个文件的嵌入值,并返回嵌入值字典

def compute_doc_embeddings(self, df: pd.DataFrame) -> dict[tuple[str, str], list[float]]:"""Create an embedding for each row in the dataframe using the OpenAI Embeddings API.Return a dictionary that maps between each embedding vector and the index of the row that it corresponds to."""embeddings = {}for idx, r in df.iterrows():# Wait one second before making the next call to the OpenAI Embeddings API# print("Waiting one second before embedding next row\n")time.sleep(1)embeddings[idx] = self.get_doc_embedding(r.content.replace("\n", " "))return embeddings

1.4 save_doc_embeddings_to_csv 保存嵌入值到文件中

该函数从文件中读取已经保存的embbeding信息,转换为一个dict

  • key为一个元组(filePath, lineCoverage)
  • value为一个数组,把其余列存放至后面
def load_embeddings(self, fname: str) -> dict[tuple[str, str], list[float]]:       df = pd.read_csv(fname, header=0)max_dim = max([int(c) for c in df.columns if c != "filePath" and c != "lineCoverage"])return {(r.filePath, r.lineCoverage): [r[str(i)] for i in range(max_dim + 1)] for _, r in df.iterrows()}

1.5 save_doc_embbedings_to_csv将嵌入值保存到csv文件中

这里处理了一下,不是讲整个嵌入值放到数组中,而是更具嵌入的维度放入到列中,不同的维度有不同的嵌入值

def save_doc_embeddings_to_csv(self, doc_embeddings: dict, df: pd.DataFrame, csv_filepath: str):# Get the dimensionality of the embedding vectors from the first element in the doc_embeddings dictionaryif len(doc_embeddings) == 0:returnEMBEDDING_DIM = len(list(doc_embeddings.values())[0])# Create a new dataframe with the filePath, lineCoverage, and embedding vector columnsembeddings_df = pd.DataFrame(columns=["filePath", "lineCoverage"] + [f"{i}" for i in range(EMBEDDING_DIM)])# Iterate over the rows in the original dataframefor idx, _ in df.iterrows():# Get the embedding vector for the current rowembedding = doc_embeddings[idx]# Create a new row in the embeddings dataframe with the filePath, lineCoverage, and embedding vector valuesrow = [idx[0], idx[1]] + embeddingembeddings_df.loc[len(embeddings_df)] = row# Save the embeddings dataframe to a CSV fileembeddings_df.to_csv(csv_filepath, index=False)

1.6 load_embeddings加载嵌入值,从文件中

def load_embeddings(self, fname: str) -> dict[tuple[str, str], list[float]]:       df = pd.read_csv(fname, header=0)max_dim = max([int(c) for c in df.columns if c != "filePath" and c != "lineCoverage"])return {(r.filePath, r.lineCoverage): [r[str(i)] for i in range(max_dim + 1)] for _, r in df.iterrows()}

2. embedding第二部分-获取代码相关性

获取相关的代码段,根据

  • 任务描述
  • 任务上下文
    获取相关的代码,相似度最高的两个代码块
def get_relevant_code_chunks(self, task_description: str, task_context: str):query = task_description + "\n" + task_contextmost_relevant_document_sections = self.order_document_sections_by_query_similarity(query, self.document_embeddings)selected_chunks = []for _, section_index in most_relevant_document_sections:try:document_section = self.df.loc[section_index]selected_chunks.append(self.SEPARATOR + document_section['content'].replace("\n", " "))if len(selected_chunks) >= 2:breakexcept:passreturn selected_chunks

这个函数有两个参数,他会根据相似度对整个文件排序

  • query 请求的文本,用作计算嵌入值
  • context 上下文,用作查找和这段文本的相似度,就是上文的字典
def order_document_sections_by_query_similarity(self, query: str, contexts: dict[(str, str), np.array]) -> list[(float, (str, str))]:"""Find the query embedding for the supplied query, and compare it against all of the pre-calculated document embeddingsto find the most relevant sections. Return the list of document sections, sorted by relevance in descending order."""query_embedding = self.get_query_embedding(query)document_similarities = sorted([(self.vector_similarity(query_embedding, doc_embedding), doc_index) for doc_index, doc_embedding in contexts.items()], reverse=True)return document_similarities

用数量积的方式计算两个向量之间的相似度,这里数量积可以表示为 a ⋅ b = ∣ a ∣ ∣ b ∣ c o s θ a{\cdot}b=|a||b|cos{\theta} ab=a∣∣bcosθ,当两个向量垂直时,计算的值为0,计算的值越大说明相似度越高

def vector_similarity(self, x: list[float], y: list[float]) -> float:return np.dot(np.array(x), np.array(y))

3. 通过OpenAI获取嵌入值相关函数

计算这段文字的嵌入值

def get_query_embedding(self, text: str) -> list[float]:return self.get_embedding(text, self.QUERY_EMBEDDINGS_MODEL)

计算文档相关性

def get_doc_embedding(self, text: str) -> list[float]:return self.get_embedding(text, self.DOC_EMBEDDINGS_MODEL)

处理openAi返回

    def get_embedding(self, text: str, model: str) -> list[float]:result = openai.Embedding.create(model=model,input=text)return result["data"][0]["embedding"]

下一篇文章,我们将进入主程序的阅读,看看embedding是如何和主程序结合的

这篇关于babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/878141

相关文章

Java 正则表达式URL 匹配与源码全解析

《Java正则表达式URL匹配与源码全解析》在Web应用开发中,我们经常需要对URL进行格式验证,今天我们结合Java的Pattern和Matcher类,深入理解正则表达式在实际应用中... 目录1.正则表达式分解:2. 添加域名匹配 (2)3. 添加路径和查询参数匹配 (3) 4. 最终优化版本5.设计思

C#实现将Excel表格转换为图片(JPG/ PNG)

《C#实现将Excel表格转换为图片(JPG/PNG)》Excel表格可能会因为不同设备或字体缺失等问题,导致格式错乱或数据显示异常,转换为图片后,能确保数据的排版等保持一致,下面我们看看如何使用C... 目录通过C# 转换Excel工作表到图片通过C# 转换指定单元格区域到图片知识扩展C# 将 Excel

基于Java实现回调监听工具类

《基于Java实现回调监听工具类》这篇文章主要为大家详细介绍了如何基于Java实现一个回调监听工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录监听接口类 Listenable实际用法打印结果首先,会用到 函数式接口 Consumer, 通过这个可以解耦回调方法,下面先写一个

使用Java将DOCX文档解析为Markdown文档的代码实现

《使用Java将DOCX文档解析为Markdown文档的代码实现》在现代文档处理中,Markdown(MD)因其简洁的语法和良好的可读性,逐渐成为开发者、技术写作者和内容创作者的首选格式,然而,许多文... 目录引言1. 工具和库介绍2. 安装依赖库3. 使用Apache POI解析DOCX文档4. 将解析

Qt中QGroupBox控件的实现

《Qt中QGroupBox控件的实现》QGroupBox是Qt框架中一个非常有用的控件,它主要用于组织和管理一组相关的控件,本文主要介绍了Qt中QGroupBox控件的实现,具有一定的参考价值,感兴趣... 目录引言一、基本属性二、常用方法2.1 构造函数 2.2 设置标题2.3 设置复选框模式2.4 是否

C++使用printf语句实现进制转换的示例代码

《C++使用printf语句实现进制转换的示例代码》在C语言中,printf函数可以直接实现部分进制转换功能,通过格式说明符(formatspecifier)快速输出不同进制的数值,下面给大家分享C+... 目录一、printf 原生支持的进制转换1. 十进制、八进制、十六进制转换2. 显示进制前缀3. 指

springboot整合阿里云百炼DeepSeek实现sse流式打印的操作方法

《springboot整合阿里云百炼DeepSeek实现sse流式打印的操作方法》:本文主要介绍springboot整合阿里云百炼DeepSeek实现sse流式打印,本文给大家介绍的非常详细,对大... 目录1.开通阿里云百炼,获取到key2.新建SpringBoot项目3.工具类4.启动类5.测试类6.测

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

SpringBoot集成Milvus实现数据增删改查功能

《SpringBoot集成Milvus实现数据增删改查功能》milvus支持的语言比较多,支持python,Java,Go,node等开发语言,本文主要介绍如何使用Java语言,采用springboo... 目录1、Milvus基本概念2、添加maven依赖3、配置yml文件4、创建MilvusClient

JS+HTML实现在线图片水印添加工具

《JS+HTML实现在线图片水印添加工具》在社交媒体和内容创作日益频繁的今天,如何保护原创内容、展示品牌身份成了一个不得不面对的问题,本文将实现一个完全基于HTML+CSS构建的现代化图片水印在线工具... 目录概述功能亮点使用方法技术解析延伸思考运行效果项目源码下载总结概述在社交媒体和内容创作日益频繁的