babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现

2024-04-05 21:04

本文主要是介绍babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在进入到主程序前,我们还需要看一个Embedding的实现代码,这里的功能主要是为了计算代码之间的相关性。
embedding可以文本中的词语转化为低维实数向量的表示,来计算两段文字间的几何距离来判断词语的含义是否相近。

1. 源码阅读-初始化和计算代码库的嵌入值

这段代码主要是设定了初始化变量,包括使用的embedding的模型,以及tokenizer(分词器),分词器按照\n,作为分词符号和分词长度。

class Embeddings:def __init__(self, workspace_path: str):self.workspace_path = workspace_pathopenai.api_key = os.getenv("OPENAI_API_KEY", "")self.DOC_EMBEDDINGS_MODEL = f"text-embedding-ada-002"self.QUERY_EMBEDDINGS_MODEL = f"text-embedding-ada-002"self.SEPARATOR = "\n* "self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")self.separator_len = len(self.tokenizer.tokenize(self.SEPARATOR))

下面的代码用于计算整个代码库的embedding值,用于查找相关代码,实现了以下步骤

  • 删除playground_data 代码空间下的所有文件,不会有旧的数据重新计算
  • 将代码文件转换为特定格式,放入repository_info.csv文件中
  • 计算repository_info.csv中内容的嵌入值,放入到doc_embeddings.csv
def compute_repository_embeddings(self):try:playground_data_path = os.path.join(self.workspace_path, 'playground_data')# Delete the contents of the playground_data directory but not the directory itself# This is to ensure that we don't have any old data lying aroundfor filename in os.listdir(playground_data_path):file_path = os.path.join(playground_data_path, filename)try:if os.path.isfile(file_path) or os.path.islink(file_path):os.unlink(file_path)elif os.path.isdir(file_path):shutil.rmtree(file_path)except Exception as e:print(f"Failed to delete {file_path}. Reason: {str(e)}")except Exception as e:print(f"Error: {str(e)}")# extract and save info to csvinfo = self.extract_info(REPOSITORY_PATH)self.save_info_to_csv(info)df = pd.read_csv(os.path.join(self.workspace_path, 'playground_data\\repository_info.csv'))df = df.set_index(["filePath", "lineCoverage"])self.df = dfcontext_embeddings = self.compute_doc_embeddings(df)self.save_doc_embeddings_to_csv(context_embeddings, df, os.path.join(self.workspace_path, 'playground_data\\doc_embeddings.csv'))try:self.document_embeddings = self.load_embeddings(os.path.join(self.workspace_path, 'playground_data\\doc_embeddings.csv'))except:pass

下面是使用到的extract_info函数、save_info_to_csv函数、compute_doc_embeddings函数、save_doc_embeddings_to_csv函数,load_embeddings 函数。

1.1 extract_info 提取代码文件信息

这个函数的功能是从文件中获取信息,转化为特定的形式一个列表,包含三个信息

  • filePath 文件路径
  • lineCoverage 为一个元组,包含两个值 第一行位置和最后一行的位置
  • chunkContent 代码的内容
# Extract information from files in the repository in chunks
# Return a list of [filePath, lineCoverage, chunkContent]
def extract_info(self, REPOSITORY_PATH):# Initialize an empty list to store the informationinfo = []LINES_PER_CHUNK = 60# Iterate through the files in the repositoryfor root, dirs, files in os.walk(REPOSITORY_PATH):for file in files:file_path = os.path.join(root, file)# Read the contents of the filewith open(file_path, "r", encoding="utf-8") as f:try:contents = f.read()except:continue# Split the contents into lineslines = contents.split("\n")# Ignore empty lineslines = [line for line in lines if line.strip()]# Split the lines into chunks of LINES_PER_CHUNK lineschunks = [lines[i:i+LINES_PER_CHUNK]for i in range(0, len(lines), LINES_PER_CHUNK)]# Iterate through the chunksfor i, chunk in enumerate(chunks):# Join the lines in the chunk back into a single stringchunk = "\n".join(chunk)# Get the first and last line numbersfirst_line = i * LINES_PER_CHUNK + 1last_line = first_line + len(chunk.split("\n")) - 1line_coverage = (first_line, last_line)# Add the file path, line coverage, and content to the listinfo.append((os.path.join(root, file), line_coverage, chunk))# Return the list of informationreturn info

1.2 save_info_to_csv保存提取出的信息

这个函数的功能是将代码信息存放到csv文件中,使用pandas库

def save_info_to_csv(self, info):# Open a CSV file for writingos.makedirs(os.path.join(self.workspace_path, "playground_data"), exist_ok=True)with open(os.path.join(self.workspace_path, 'playground_data\\repository_info.csv'), "w", newline="") as csvfile:# Create a CSV writerwriter = csv.writer(csvfile)# Write the header rowwriter.writerow(["filePath", "lineCoverage", "content"])# Iterate through the infofor file_path, line_coverage, content in info:# Write a row for each chunk of datawriter.writerow([file_path, line_coverage, content])

1.3 compute_doc_embeddings计算文档的嵌入值信息

计算每个文件的嵌入值,并返回嵌入值字典

def compute_doc_embeddings(self, df: pd.DataFrame) -> dict[tuple[str, str], list[float]]:"""Create an embedding for each row in the dataframe using the OpenAI Embeddings API.Return a dictionary that maps between each embedding vector and the index of the row that it corresponds to."""embeddings = {}for idx, r in df.iterrows():# Wait one second before making the next call to the OpenAI Embeddings API# print("Waiting one second before embedding next row\n")time.sleep(1)embeddings[idx] = self.get_doc_embedding(r.content.replace("\n", " "))return embeddings

1.4 save_doc_embeddings_to_csv 保存嵌入值到文件中

该函数从文件中读取已经保存的embbeding信息,转换为一个dict

  • key为一个元组(filePath, lineCoverage)
  • value为一个数组,把其余列存放至后面
def load_embeddings(self, fname: str) -> dict[tuple[str, str], list[float]]:       df = pd.read_csv(fname, header=0)max_dim = max([int(c) for c in df.columns if c != "filePath" and c != "lineCoverage"])return {(r.filePath, r.lineCoverage): [r[str(i)] for i in range(max_dim + 1)] for _, r in df.iterrows()}

1.5 save_doc_embbedings_to_csv将嵌入值保存到csv文件中

这里处理了一下,不是讲整个嵌入值放到数组中,而是更具嵌入的维度放入到列中,不同的维度有不同的嵌入值

def save_doc_embeddings_to_csv(self, doc_embeddings: dict, df: pd.DataFrame, csv_filepath: str):# Get the dimensionality of the embedding vectors from the first element in the doc_embeddings dictionaryif len(doc_embeddings) == 0:returnEMBEDDING_DIM = len(list(doc_embeddings.values())[0])# Create a new dataframe with the filePath, lineCoverage, and embedding vector columnsembeddings_df = pd.DataFrame(columns=["filePath", "lineCoverage"] + [f"{i}" for i in range(EMBEDDING_DIM)])# Iterate over the rows in the original dataframefor idx, _ in df.iterrows():# Get the embedding vector for the current rowembedding = doc_embeddings[idx]# Create a new row in the embeddings dataframe with the filePath, lineCoverage, and embedding vector valuesrow = [idx[0], idx[1]] + embeddingembeddings_df.loc[len(embeddings_df)] = row# Save the embeddings dataframe to a CSV fileembeddings_df.to_csv(csv_filepath, index=False)

1.6 load_embeddings加载嵌入值,从文件中

def load_embeddings(self, fname: str) -> dict[tuple[str, str], list[float]]:       df = pd.read_csv(fname, header=0)max_dim = max([int(c) for c in df.columns if c != "filePath" and c != "lineCoverage"])return {(r.filePath, r.lineCoverage): [r[str(i)] for i in range(max_dim + 1)] for _, r in df.iterrows()}

2. embedding第二部分-获取代码相关性

获取相关的代码段,根据

  • 任务描述
  • 任务上下文
    获取相关的代码,相似度最高的两个代码块
def get_relevant_code_chunks(self, task_description: str, task_context: str):query = task_description + "\n" + task_contextmost_relevant_document_sections = self.order_document_sections_by_query_similarity(query, self.document_embeddings)selected_chunks = []for _, section_index in most_relevant_document_sections:try:document_section = self.df.loc[section_index]selected_chunks.append(self.SEPARATOR + document_section['content'].replace("\n", " "))if len(selected_chunks) >= 2:breakexcept:passreturn selected_chunks

这个函数有两个参数,他会根据相似度对整个文件排序

  • query 请求的文本,用作计算嵌入值
  • context 上下文,用作查找和这段文本的相似度,就是上文的字典
def order_document_sections_by_query_similarity(self, query: str, contexts: dict[(str, str), np.array]) -> list[(float, (str, str))]:"""Find the query embedding for the supplied query, and compare it against all of the pre-calculated document embeddingsto find the most relevant sections. Return the list of document sections, sorted by relevance in descending order."""query_embedding = self.get_query_embedding(query)document_similarities = sorted([(self.vector_similarity(query_embedding, doc_embedding), doc_index) for doc_index, doc_embedding in contexts.items()], reverse=True)return document_similarities

用数量积的方式计算两个向量之间的相似度,这里数量积可以表示为 a ⋅ b = ∣ a ∣ ∣ b ∣ c o s θ a{\cdot}b=|a||b|cos{\theta} ab=a∣∣bcosθ,当两个向量垂直时,计算的值为0,计算的值越大说明相似度越高

def vector_similarity(self, x: list[float], y: list[float]) -> float:return np.dot(np.array(x), np.array(y))

3. 通过OpenAI获取嵌入值相关函数

计算这段文字的嵌入值

def get_query_embedding(self, text: str) -> list[float]:return self.get_embedding(text, self.QUERY_EMBEDDINGS_MODEL)

计算文档相关性

def get_doc_embedding(self, text: str) -> list[float]:return self.get_embedding(text, self.DOC_EMBEDDINGS_MODEL)

处理openAi返回

    def get_embedding(self, text: str, model: str) -> list[float]:result = openai.Embedding.create(model=model,input=text)return result["data"][0]["embedding"]

下一篇文章,我们将进入主程序的阅读,看看embedding是如何和主程序结合的

这篇关于babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/878141

相关文章

python使用watchdog实现文件资源监控

《python使用watchdog实现文件资源监控》watchdog支持跨平台文件资源监控,可以检测指定文件夹下文件及文件夹变动,下面我们来看看Python如何使用watchdog实现文件资源监控吧... python文件监控库watchdogs简介随着Python在各种应用领域中的广泛使用,其生态环境也

el-select下拉选择缓存的实现

《el-select下拉选择缓存的实现》本文主要介绍了在使用el-select实现下拉选择缓存时遇到的问题及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录项目场景:问题描述解决方案:项目场景:从左侧列表中选取字段填入右侧下拉多选框,用户可以对右侧

SpringCloud集成AlloyDB的示例代码

《SpringCloud集成AlloyDB的示例代码》AlloyDB是GoogleCloud提供的一种高度可扩展、强性能的关系型数据库服务,它兼容PostgreSQL,并提供了更快的查询性能... 目录1.AlloyDBjavascript是什么?AlloyDB 的工作原理2.搭建测试环境3.代码工程1.

Java调用Python代码的几种方法小结

《Java调用Python代码的几种方法小结》Python语言有丰富的系统管理、数据处理、统计类软件包,因此从java应用中调用Python代码的需求很常见、实用,本文介绍几种方法从java调用Pyt... 目录引言Java core使用ProcessBuilder使用Java脚本引擎总结引言python

Java中ArrayList的8种浅拷贝方式示例代码

《Java中ArrayList的8种浅拷贝方式示例代码》:本文主要介绍Java中ArrayList的8种浅拷贝方式的相关资料,讲解了Java中ArrayList的浅拷贝概念,并详细分享了八种实现浅... 目录引言什么是浅拷贝?ArrayList 浅拷贝的重要性方法一:使用构造函数方法二:使用 addAll(

Python pyinstaller实现图形化打包工具

《Pythonpyinstaller实现图形化打包工具》:本文主要介绍一个使用PythonPYQT5制作的关于pyinstaller打包工具,代替传统的cmd黑窗口模式打包页面,实现更快捷方便的... 目录1.简介2.运行效果3.相关源码1.简介一个使用python PYQT5制作的关于pyinstall

使用Python实现大文件切片上传及断点续传的方法

《使用Python实现大文件切片上传及断点续传的方法》本文介绍了使用Python实现大文件切片上传及断点续传的方法,包括功能模块划分(获取上传文件接口状态、临时文件夹状态信息、切片上传、切片合并)、整... 目录概要整体架构流程技术细节获取上传文件状态接口获取临时文件夹状态信息接口切片上传功能文件合并功能小

python实现自动登录12306自动抢票功能

《python实现自动登录12306自动抢票功能》随着互联网技术的发展,越来越多的人选择通过网络平台购票,特别是在中国,12306作为官方火车票预订平台,承担了巨大的访问量,对于热门线路或者节假日出行... 目录一、遇到的问题?二、改进三、进阶–展望总结一、遇到的问题?1.url-正确的表头:就是首先ur

C#实现文件读写到SQLite数据库

《C#实现文件读写到SQLite数据库》这篇文章主要为大家详细介绍了使用C#将文件读写到SQLite数据库的几种方法,文中的示例代码讲解详细,感兴趣的小伙伴可以参考一下... 目录1. 使用 BLOB 存储文件2. 存储文件路径3. 分块存储文件《文件读写到SQLite数据库China编程的方法》博客中,介绍了文

Java汇编源码如何查看环境搭建

《Java汇编源码如何查看环境搭建》:本文主要介绍如何在IntelliJIDEA开发环境中搭建字节码和汇编环境,以便更好地进行代码调优和JVM学习,首先,介绍了如何配置IntelliJIDEA以方... 目录一、简介二、在IDEA开发环境中搭建汇编环境2.1 在IDEA中搭建字节码查看环境2.1.1 搭建步