babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现

2024-04-05 21:04

本文主要是介绍babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在进入到主程序前,我们还需要看一个Embedding的实现代码,这里的功能主要是为了计算代码之间的相关性。
embedding可以文本中的词语转化为低维实数向量的表示,来计算两段文字间的几何距离来判断词语的含义是否相近。

1. 源码阅读-初始化和计算代码库的嵌入值

这段代码主要是设定了初始化变量,包括使用的embedding的模型,以及tokenizer(分词器),分词器按照\n,作为分词符号和分词长度。

class Embeddings:def __init__(self, workspace_path: str):self.workspace_path = workspace_pathopenai.api_key = os.getenv("OPENAI_API_KEY", "")self.DOC_EMBEDDINGS_MODEL = f"text-embedding-ada-002"self.QUERY_EMBEDDINGS_MODEL = f"text-embedding-ada-002"self.SEPARATOR = "\n* "self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")self.separator_len = len(self.tokenizer.tokenize(self.SEPARATOR))

下面的代码用于计算整个代码库的embedding值,用于查找相关代码,实现了以下步骤

  • 删除playground_data 代码空间下的所有文件,不会有旧的数据重新计算
  • 将代码文件转换为特定格式,放入repository_info.csv文件中
  • 计算repository_info.csv中内容的嵌入值,放入到doc_embeddings.csv
def compute_repository_embeddings(self):try:playground_data_path = os.path.join(self.workspace_path, 'playground_data')# Delete the contents of the playground_data directory but not the directory itself# This is to ensure that we don't have any old data lying aroundfor filename in os.listdir(playground_data_path):file_path = os.path.join(playground_data_path, filename)try:if os.path.isfile(file_path) or os.path.islink(file_path):os.unlink(file_path)elif os.path.isdir(file_path):shutil.rmtree(file_path)except Exception as e:print(f"Failed to delete {file_path}. Reason: {str(e)}")except Exception as e:print(f"Error: {str(e)}")# extract and save info to csvinfo = self.extract_info(REPOSITORY_PATH)self.save_info_to_csv(info)df = pd.read_csv(os.path.join(self.workspace_path, 'playground_data\\repository_info.csv'))df = df.set_index(["filePath", "lineCoverage"])self.df = dfcontext_embeddings = self.compute_doc_embeddings(df)self.save_doc_embeddings_to_csv(context_embeddings, df, os.path.join(self.workspace_path, 'playground_data\\doc_embeddings.csv'))try:self.document_embeddings = self.load_embeddings(os.path.join(self.workspace_path, 'playground_data\\doc_embeddings.csv'))except:pass

下面是使用到的extract_info函数、save_info_to_csv函数、compute_doc_embeddings函数、save_doc_embeddings_to_csv函数,load_embeddings 函数。

1.1 extract_info 提取代码文件信息

这个函数的功能是从文件中获取信息,转化为特定的形式一个列表,包含三个信息

  • filePath 文件路径
  • lineCoverage 为一个元组,包含两个值 第一行位置和最后一行的位置
  • chunkContent 代码的内容
# Extract information from files in the repository in chunks
# Return a list of [filePath, lineCoverage, chunkContent]
def extract_info(self, REPOSITORY_PATH):# Initialize an empty list to store the informationinfo = []LINES_PER_CHUNK = 60# Iterate through the files in the repositoryfor root, dirs, files in os.walk(REPOSITORY_PATH):for file in files:file_path = os.path.join(root, file)# Read the contents of the filewith open(file_path, "r", encoding="utf-8") as f:try:contents = f.read()except:continue# Split the contents into lineslines = contents.split("\n")# Ignore empty lineslines = [line for line in lines if line.strip()]# Split the lines into chunks of LINES_PER_CHUNK lineschunks = [lines[i:i+LINES_PER_CHUNK]for i in range(0, len(lines), LINES_PER_CHUNK)]# Iterate through the chunksfor i, chunk in enumerate(chunks):# Join the lines in the chunk back into a single stringchunk = "\n".join(chunk)# Get the first and last line numbersfirst_line = i * LINES_PER_CHUNK + 1last_line = first_line + len(chunk.split("\n")) - 1line_coverage = (first_line, last_line)# Add the file path, line coverage, and content to the listinfo.append((os.path.join(root, file), line_coverage, chunk))# Return the list of informationreturn info

1.2 save_info_to_csv保存提取出的信息

这个函数的功能是将代码信息存放到csv文件中,使用pandas库

def save_info_to_csv(self, info):# Open a CSV file for writingos.makedirs(os.path.join(self.workspace_path, "playground_data"), exist_ok=True)with open(os.path.join(self.workspace_path, 'playground_data\\repository_info.csv'), "w", newline="") as csvfile:# Create a CSV writerwriter = csv.writer(csvfile)# Write the header rowwriter.writerow(["filePath", "lineCoverage", "content"])# Iterate through the infofor file_path, line_coverage, content in info:# Write a row for each chunk of datawriter.writerow([file_path, line_coverage, content])

1.3 compute_doc_embeddings计算文档的嵌入值信息

计算每个文件的嵌入值,并返回嵌入值字典

def compute_doc_embeddings(self, df: pd.DataFrame) -> dict[tuple[str, str], list[float]]:"""Create an embedding for each row in the dataframe using the OpenAI Embeddings API.Return a dictionary that maps between each embedding vector and the index of the row that it corresponds to."""embeddings = {}for idx, r in df.iterrows():# Wait one second before making the next call to the OpenAI Embeddings API# print("Waiting one second before embedding next row\n")time.sleep(1)embeddings[idx] = self.get_doc_embedding(r.content.replace("\n", " "))return embeddings

1.4 save_doc_embeddings_to_csv 保存嵌入值到文件中

该函数从文件中读取已经保存的embbeding信息,转换为一个dict

  • key为一个元组(filePath, lineCoverage)
  • value为一个数组,把其余列存放至后面
def load_embeddings(self, fname: str) -> dict[tuple[str, str], list[float]]:       df = pd.read_csv(fname, header=0)max_dim = max([int(c) for c in df.columns if c != "filePath" and c != "lineCoverage"])return {(r.filePath, r.lineCoverage): [r[str(i)] for i in range(max_dim + 1)] for _, r in df.iterrows()}

1.5 save_doc_embbedings_to_csv将嵌入值保存到csv文件中

这里处理了一下,不是讲整个嵌入值放到数组中,而是更具嵌入的维度放入到列中,不同的维度有不同的嵌入值

def save_doc_embeddings_to_csv(self, doc_embeddings: dict, df: pd.DataFrame, csv_filepath: str):# Get the dimensionality of the embedding vectors from the first element in the doc_embeddings dictionaryif len(doc_embeddings) == 0:returnEMBEDDING_DIM = len(list(doc_embeddings.values())[0])# Create a new dataframe with the filePath, lineCoverage, and embedding vector columnsembeddings_df = pd.DataFrame(columns=["filePath", "lineCoverage"] + [f"{i}" for i in range(EMBEDDING_DIM)])# Iterate over the rows in the original dataframefor idx, _ in df.iterrows():# Get the embedding vector for the current rowembedding = doc_embeddings[idx]# Create a new row in the embeddings dataframe with the filePath, lineCoverage, and embedding vector valuesrow = [idx[0], idx[1]] + embeddingembeddings_df.loc[len(embeddings_df)] = row# Save the embeddings dataframe to a CSV fileembeddings_df.to_csv(csv_filepath, index=False)

1.6 load_embeddings加载嵌入值,从文件中

def load_embeddings(self, fname: str) -> dict[tuple[str, str], list[float]]:       df = pd.read_csv(fname, header=0)max_dim = max([int(c) for c in df.columns if c != "filePath" and c != "lineCoverage"])return {(r.filePath, r.lineCoverage): [r[str(i)] for i in range(max_dim + 1)] for _, r in df.iterrows()}

2. embedding第二部分-获取代码相关性

获取相关的代码段,根据

  • 任务描述
  • 任务上下文
    获取相关的代码,相似度最高的两个代码块
def get_relevant_code_chunks(self, task_description: str, task_context: str):query = task_description + "\n" + task_contextmost_relevant_document_sections = self.order_document_sections_by_query_similarity(query, self.document_embeddings)selected_chunks = []for _, section_index in most_relevant_document_sections:try:document_section = self.df.loc[section_index]selected_chunks.append(self.SEPARATOR + document_section['content'].replace("\n", " "))if len(selected_chunks) >= 2:breakexcept:passreturn selected_chunks

这个函数有两个参数,他会根据相似度对整个文件排序

  • query 请求的文本,用作计算嵌入值
  • context 上下文,用作查找和这段文本的相似度,就是上文的字典
def order_document_sections_by_query_similarity(self, query: str, contexts: dict[(str, str), np.array]) -> list[(float, (str, str))]:"""Find the query embedding for the supplied query, and compare it against all of the pre-calculated document embeddingsto find the most relevant sections. Return the list of document sections, sorted by relevance in descending order."""query_embedding = self.get_query_embedding(query)document_similarities = sorted([(self.vector_similarity(query_embedding, doc_embedding), doc_index) for doc_index, doc_embedding in contexts.items()], reverse=True)return document_similarities

用数量积的方式计算两个向量之间的相似度,这里数量积可以表示为 a ⋅ b = ∣ a ∣ ∣ b ∣ c o s θ a{\cdot}b=|a||b|cos{\theta} ab=a∣∣bcosθ,当两个向量垂直时,计算的值为0,计算的值越大说明相似度越高

def vector_similarity(self, x: list[float], y: list[float]) -> float:return np.dot(np.array(x), np.array(y))

3. 通过OpenAI获取嵌入值相关函数

计算这段文字的嵌入值

def get_query_embedding(self, text: str) -> list[float]:return self.get_embedding(text, self.QUERY_EMBEDDINGS_MODEL)

计算文档相关性

def get_doc_embedding(self, text: str) -> list[float]:return self.get_embedding(text, self.DOC_EMBEDDINGS_MODEL)

处理openAi返回

    def get_embedding(self, text: str, model: str) -> list[float]:result = openai.Embedding.create(model=model,input=text)return result["data"][0]["embedding"]

下一篇文章,我们将进入主程序的阅读,看看embedding是如何和主程序结合的

这篇关于babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/878141

相关文章

Nginx实现高并发的项目实践

《Nginx实现高并发的项目实践》本文主要介绍了Nginx实现高并发的项目实践,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录使用最新稳定版本的Nginx合理配置工作进程(workers)配置工作进程连接数(worker_co

python中列表list切分的实现

《python中列表list切分的实现》列表是Python中最常用的数据结构之一,经常需要对列表进行切分操作,本文主要介绍了python中列表list切分的实现,文中通过示例代码介绍的非常详细,对大家... 目录一、列表切片的基本用法1.1 基本切片操作1.2 切片的负索引1.3 切片的省略二、列表切分的高

基于Python实现一个PDF特殊字体提取工具

《基于Python实现一个PDF特殊字体提取工具》在PDF文档处理场景中,我们常常需要针对特定格式的文本内容进行提取分析,本文介绍的PDF特殊字体提取器是一款基于Python开发的桌面应用程序感兴趣的... 目录一、应用背景与功能概述二、技术架构与核心组件2.1 技术选型2.2 系统架构三、核心功能实现解析

Flutter监听当前页面可见与隐藏状态的代码详解

《Flutter监听当前页面可见与隐藏状态的代码详解》文章介绍了如何在Flutter中使用路由观察者来监听应用进入前台或后台状态以及页面的显示和隐藏,并通过代码示例讲解的非常详细,需要的朋友可以参考下... flutter 可以监听 app 进入前台还是后台状态,也可以监听当http://www.cppcn

Python使用PIL库将PNG图片转换为ICO图标的示例代码

《Python使用PIL库将PNG图片转换为ICO图标的示例代码》在软件开发和网站设计中,ICO图标是一种常用的图像格式,特别适用于应用程序图标、网页收藏夹图标等场景,本文将介绍如何使用Python的... 目录引言准备工作代码解析实践操作结果展示结语引言在软件开发和网站设计中,ICO图标是一种常用的图像

使用Python实现表格字段智能去重

《使用Python实现表格字段智能去重》在数据分析和处理过程中,数据清洗是一个至关重要的步骤,其中字段去重是一个常见且关键的任务,下面我们看看如何使用Python进行表格字段智能去重吧... 目录一、引言二、数据重复问题的常见场景与影响三、python在数据清洗中的优势四、基于Python的表格字段智能去重

Spring AI集成DeepSeek实现流式输出的操作方法

《SpringAI集成DeepSeek实现流式输出的操作方法》本文介绍了如何在SpringBoot中使用Sse(Server-SentEvents)技术实现流式输出,后端使用SpringMVC中的S... 目录一、后端代码二、前端代码三、运行项目小天有话说题外话参考资料前面一篇文章我们实现了《Spring

Nginx中location实现多条件匹配的方法详解

《Nginx中location实现多条件匹配的方法详解》在Nginx中,location指令用于匹配请求的URI,虽然location本身是基于单一匹配规则的,但可以通过多种方式实现多个条件的匹配逻辑... 目录1. 概述2. 实现多条件匹配的方式2.1 使用多个 location 块2.2 使用正则表达式

使用Apache POI在Java中实现Excel单元格的合并

《使用ApachePOI在Java中实现Excel单元格的合并》在日常工作中,Excel是一个不可或缺的工具,尤其是在处理大量数据时,本文将介绍如何使用ApachePOI库在Java中实现Excel... 目录工具类介绍工具类代码调用示例依赖配置总结在日常工作中,Excel 是一个不可或缺的工http://

SpringBoot实现导出复杂对象到Excel文件

《SpringBoot实现导出复杂对象到Excel文件》这篇文章主要为大家详细介绍了如何使用Hutool和EasyExcel两种方式来实现在SpringBoot项目中导出复杂对象到Excel文件,需要... 在Spring Boot项目中导出复杂对象到Excel文件,可以利用Hutool或EasyExcel