Open AI 利用语义嵌入(Embedding)来做一个知识库

要使用 OpenAI 的 Embedding API 创建一个带有数据库的系统,首先需要通过 API 调用获取文本的语义嵌入(Embedding),然后将这些嵌入存储在数据库中,供后续检索和比较使用。你可以使用 PostgreSQL 或 SQLite 等数据库,以及 Python 来实现这个系统。

以下是整个流程及代码示例:

流程步骤:

    • 安装 OpenAI 的 Python SDK 和数据库库(例如 psycopg2sqlite3):
  1. 设置 OpenAI API 密钥:
    • 获取 OpenAI 的 API 密钥并配置环境变量,或者在代码中直接使用密钥。
  2. 创建数据库:
    • 在数据库中创建一个表,用于存储文本和生成的嵌入(向量)。
  3. API 调用生成嵌入:
    • 调用 OpenAI 的 Embedding API 来生成语义嵌入。
  4. 存储嵌入到数据库:
    • 将文本及其对应的嵌入向量存储在数据库中。
  5. 检索和比较:
    • 使用嵌入检索,比较输入文本的嵌入与存储在数据库中的嵌入,找到最相似的文本。

安装依赖:

pip install openai psycopg2-binary  # For PostgreSQL

代码实现

1. 数据库初始化

import psycopg2

# 初始化数据库连接(以 PostgreSQL 为例)
def init_db():
    conn = psycopg2.connect(
        host="localhost",
        database="embeddings_db",
        user="your_user",
        password="your_password"
    )
    return conn

# 创建表格存储文本和嵌入
def create_table():
    conn = init_db()
    cur = conn.cursor()
    cur.execute("""
        CREATE TABLE IF NOT EXISTS embeddings (
            id SERIAL PRIMARY KEY,
            text TEXT NOT NULL,
            embedding FLOAT8[] NOT NULL
        );
    """)
    conn.commit()
    cur.close()
    conn.close()

# 初始化数据库
create_table()

2. 调用 OpenAI API 生成嵌入

import openai

# 设置 OpenAI API 密钥
openai.api_key = "your_openai_api_key"

# 调用 OpenAI API 获取嵌入
def get_embedding(text):
    response = openai.Embedding.create(
        model="text-embedding-ada-002",
        input=text
    )
    embedding = response["data"][0]["embedding"]
    return embedding

3. 插入文本及其嵌入到数据库

def insert_embedding(text, embedding):
    conn = init_db()
    cur = conn.cursor()
    cur.execute("""
        INSERT INTO embeddings (text, embedding) 
        VALUES (%s, %s);
    """, (text, embedding))
    conn.commit()
    cur.close()
    conn.close()

# 插入示例文本和嵌入
text = "This is an example sentence for embedding."
embedding = get_embedding(text)
insert_embedding(text, embedding)

4. 从数据库中检索并比较嵌入

为了比较嵌入的相似性,通常会计算两个向量之间的余弦相似度。以下是如何实现文本检索和相似度计算:

import numpy as np

# 计算余弦相似度
def cosine_similarity(vec1, vec2):
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

# 从数据库中检索最相似的文本
def find_similar_text(input_text):
    input_embedding = get_embedding(input_text)
    
    conn = init_db()
    cur = conn.cursor()
    
    cur.execute("SELECT text, embedding FROM embeddings;")
    rows = cur.fetchall()
    
    max_similarity = -1
    most_similar_text = None
    
    for row in rows:
        db_text, db_embedding = row
        db_embedding = np.array(db_embedding, dtype=np.float32)
        
        similarity = cosine_similarity(input_embedding, db_embedding)
        if similarity > max_similarity:
            max_similarity = similarity
            most_similar_text = db_text
    
    cur.close()
    conn.close()
    
    return most_similar_text, max_similarity

# 查询相似文本
input_text = "Find the most similar sentence to this one."
similar_text, similarity = find_similar_text(input_text)
print(f"Most similar text: {similar_text} with similarity score: {similarity}")

5. 启动 API(可选)

如果你希望通过 API 调用这一功能,可以使用 FastAPI 创建 Web 服务:

from fastapi import FastAPI

app = FastAPI()

@app.post("/get_similar_text/")
def get_similar(input_text: str):
    similar_text, similarity = find_similar_text(input_text)
    return {"similar_text": similar_text, "similarity": similarity}

# 运行 FastAPI 服务器
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

流程总结:

  1. 获取 API 密钥并设置。
  2. 创建数据库并存储文本和嵌入向量。
  3. 生成嵌入并将其存储到数据库中。
  4. 实现相似度计算(余弦相似度)用于文本检索。
  5. 可选:通过 API 服务进行远程访问。

这个系统可以扩展为更复杂的知识库、文档检索系统或语义搜索引擎。