本文主要是介绍babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
在进入到主程序前,我们还需要看一个Embedding
的实现代码,这里的功能主要是为了计算代码之间的相关性。
embedding可以文本中的词语转化为低维实数向量的表示,来计算两段文字间的几何距离来判断词语的含义是否相近。
1. 源码阅读-初始化和计算代码库的嵌入值
这段代码主要是设定了初始化变量,包括使用的embedding的模型,以及tokenizer(分词器),分词器按照\n
,作为分词符号和分词长度。
class Embeddings:def __init__(self, workspace_path: str):self.workspace_path = workspace_pathopenai.api_key = os.getenv("OPENAI_API_KEY", "")self.DOC_EMBEDDINGS_MODEL = f"text-embedding-ada-002"self.QUERY_EMBEDDINGS_MODEL = f"text-embedding-ada-002"self.SEPARATOR = "\n* "self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")self.separator_len = len(self.tokenizer.tokenize(self.SEPARATOR))
下面的代码用于计算整个代码库的embedding值,用于查找相关代码,实现了以下步骤
- 删除
playground_data
代码空间下的所有文件,不会有旧的数据重新计算 - 将代码文件转换为特定格式,放入
repository_info.csv
文件中 - 计算
repository_info.csv
中内容的嵌入值,放入到doc_embeddings.csv
中
def compute_repository_embeddings(self):try:playground_data_path = os.path.join(self.workspace_path, 'playground_data')# Delete the contents of the playground_data directory but not the directory itself# This is to ensure that we don't have any old data lying aroundfor filename in os.listdir(playground_data_path):file_path = os.path.join(playground_data_path, filename)try:if os.path.isfile(file_path) or os.path.islink(file_path):os.unlink(file_path)elif os.path.isdir(file_path):shutil.rmtree(file_path)except Exception as e:print(f"Failed to delete {file_path}. Reason: {str(e)}")except Exception as e:print(f"Error: {str(e)}")# extract and save info to csvinfo = self.extract_info(REPOSITORY_PATH)self.save_info_to_csv(info)df = pd.read_csv(os.path.join(self.workspace_path, 'playground_data\\repository_info.csv'))df = df.set_index(["filePath", "lineCoverage"])self.df = dfcontext_embeddings = self.compute_doc_embeddings(df)self.save_doc_embeddings_to_csv(context_embeddings, df, os.path.join(self.workspace_path, 'playground_data\\doc_embeddings.csv'))try:self.document_embeddings = self.load_embeddings(os.path.join(self.workspace_path, 'playground_data\\doc_embeddings.csv'))except:pass
下面是使用到的extract_info
函数、save_info_to_csv
函数、compute_doc_embeddings
函数、save_doc_embeddings_to_csv
函数,load_embeddings
函数。
1.1 extract_info
提取代码文件信息
这个函数的功能是从文件中获取信息,转化为特定的形式一个列表,包含三个信息
filePath
文件路径lineCoverage
为一个元组,包含两个值 第一行位置和最后一行的位置chunkContent
代码的内容
# Extract information from files in the repository in chunks
# Return a list of [filePath, lineCoverage, chunkContent]
def extract_info(self, REPOSITORY_PATH):# Initialize an empty list to store the informationinfo = []LINES_PER_CHUNK = 60# Iterate through the files in the repositoryfor root, dirs, files in os.walk(REPOSITORY_PATH):for file in files:file_path = os.path.join(root, file)# Read the contents of the filewith open(file_path, "r", encoding="utf-8") as f:try:contents = f.read()except:continue# Split the contents into lineslines = contents.split("\n")# Ignore empty lineslines = [line for line in lines if line.strip()]# Split the lines into chunks of LINES_PER_CHUNK lineschunks = [lines[i:i+LINES_PER_CHUNK]for i in range(0, len(lines), LINES_PER_CHUNK)]# Iterate through the chunksfor i, chunk in enumerate(chunks):# Join the lines in the chunk back into a single stringchunk = "\n".join(chunk)# Get the first and last line numbersfirst_line = i * LINES_PER_CHUNK + 1last_line = first_line + len(chunk.split("\n")) - 1line_coverage = (first_line, last_line)# Add the file path, line coverage, and content to the listinfo.append((os.path.join(root, file), line_coverage, chunk))# Return the list of informationreturn info
1.2 save_info_to_csv
保存提取出的信息
这个函数的功能是将代码信息存放到csv文件中,使用pandas库
def save_info_to_csv(self, info):# Open a CSV file for writingos.makedirs(os.path.join(self.workspace_path, "playground_data"), exist_ok=True)with open(os.path.join(self.workspace_path, 'playground_data\\repository_info.csv'), "w", newline="") as csvfile:# Create a CSV writerwriter = csv.writer(csvfile)# Write the header rowwriter.writerow(["filePath", "lineCoverage", "content"])# Iterate through the infofor file_path, line_coverage, content in info:# Write a row for each chunk of datawriter.writerow([file_path, line_coverage, content])
1.3 compute_doc_embeddings
计算文档的嵌入值信息
计算每个文件的嵌入值,并返回嵌入值字典
def compute_doc_embeddings(self, df: pd.DataFrame) -> dict[tuple[str, str], list[float]]:"""Create an embedding for each row in the dataframe using the OpenAI Embeddings API.Return a dictionary that maps between each embedding vector and the index of the row that it corresponds to."""embeddings = {}for idx, r in df.iterrows():# Wait one second before making the next call to the OpenAI Embeddings API# print("Waiting one second before embedding next row\n")time.sleep(1)embeddings[idx] = self.get_doc_embedding(r.content.replace("\n", " "))return embeddings
1.4 save_doc_embeddings_to_csv
保存嵌入值到文件中
该函数从文件中读取已经保存的embbeding信息,转换为一个dict
- key为一个元组(filePath, lineCoverage)
- value为一个数组,把其余列存放至后面
def load_embeddings(self, fname: str) -> dict[tuple[str, str], list[float]]: df = pd.read_csv(fname, header=0)max_dim = max([int(c) for c in df.columns if c != "filePath" and c != "lineCoverage"])return {(r.filePath, r.lineCoverage): [r[str(i)] for i in range(max_dim + 1)] for _, r in df.iterrows()}
1.5 save_doc_embbedings_to_csv
将嵌入值保存到csv文件中
这里处理了一下,不是讲整个嵌入值放到数组中,而是更具嵌入的维度放入到列中,不同的维度有不同的嵌入值
def save_doc_embeddings_to_csv(self, doc_embeddings: dict, df: pd.DataFrame, csv_filepath: str):# Get the dimensionality of the embedding vectors from the first element in the doc_embeddings dictionaryif len(doc_embeddings) == 0:returnEMBEDDING_DIM = len(list(doc_embeddings.values())[0])# Create a new dataframe with the filePath, lineCoverage, and embedding vector columnsembeddings_df = pd.DataFrame(columns=["filePath", "lineCoverage"] + [f"{i}" for i in range(EMBEDDING_DIM)])# Iterate over the rows in the original dataframefor idx, _ in df.iterrows():# Get the embedding vector for the current rowembedding = doc_embeddings[idx]# Create a new row in the embeddings dataframe with the filePath, lineCoverage, and embedding vector valuesrow = [idx[0], idx[1]] + embeddingembeddings_df.loc[len(embeddings_df)] = row# Save the embeddings dataframe to a CSV fileembeddings_df.to_csv(csv_filepath, index=False)
1.6 load_embeddings
加载嵌入值,从文件中
def load_embeddings(self, fname: str) -> dict[tuple[str, str], list[float]]: df = pd.read_csv(fname, header=0)max_dim = max([int(c) for c in df.columns if c != "filePath" and c != "lineCoverage"])return {(r.filePath, r.lineCoverage): [r[str(i)] for i in range(max_dim + 1)] for _, r in df.iterrows()}
2. embedding第二部分-获取代码相关性
获取相关的代码段,根据
- 任务描述
- 任务上下文
获取相关的代码,相似度最高的两个代码块
def get_relevant_code_chunks(self, task_description: str, task_context: str):query = task_description + "\n" + task_contextmost_relevant_document_sections = self.order_document_sections_by_query_similarity(query, self.document_embeddings)selected_chunks = []for _, section_index in most_relevant_document_sections:try:document_section = self.df.loc[section_index]selected_chunks.append(self.SEPARATOR + document_section['content'].replace("\n", " "))if len(selected_chunks) >= 2:breakexcept:passreturn selected_chunks
这个函数有两个参数,他会根据相似度对整个文件排序
- query 请求的文本,用作计算嵌入值
- context 上下文,用作查找和这段文本的相似度,就是上文的字典
def order_document_sections_by_query_similarity(self, query: str, contexts: dict[(str, str), np.array]) -> list[(float, (str, str))]:"""Find the query embedding for the supplied query, and compare it against all of the pre-calculated document embeddingsto find the most relevant sections. Return the list of document sections, sorted by relevance in descending order."""query_embedding = self.get_query_embedding(query)document_similarities = sorted([(self.vector_similarity(query_embedding, doc_embedding), doc_index) for doc_index, doc_embedding in contexts.items()], reverse=True)return document_similarities
用数量积的方式计算两个向量之间的相似度,这里数量积可以表示为 a ⋅ b = ∣ a ∣ ∣ b ∣ c o s θ a{\cdot}b=|a||b|cos{\theta} a⋅b=∣a∣∣b∣cosθ,当两个向量垂直时,计算的值为0,计算的值越大说明相似度越高
def vector_similarity(self, x: list[float], y: list[float]) -> float:return np.dot(np.array(x), np.array(y))
3. 通过OpenAI获取嵌入值相关函数
计算这段文字的嵌入值
def get_query_embedding(self, text: str) -> list[float]:return self.get_embedding(text, self.QUERY_EMBEDDINGS_MODEL)
计算文档相关性
def get_doc_embedding(self, text: str) -> list[float]:return self.get_embedding(text, self.DOC_EMBEDDINGS_MODEL)
处理openAi返回
def get_embedding(self, text: str, model: str) -> list[float]:result = openai.Embedding.create(model=model,input=text)return result["data"][0]["embedding"]
下一篇文章,我们将进入主程序的阅读,看看embedding是如何和主程序结合的
这篇关于babyAGI(6)-babyCoder源码阅读4_Embbeding代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!