diff --git a/openai/embeddings_utils.py b/openai/embeddings_utils.py
index 08fa94c2ea..f711d3e42a 100644
--- a/openai/embeddings_utils.py
+++ b/openai/embeddings_utils.py
@@ -15,22 +15,16 @@
 
 
 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-def get_embedding(text: str, engine="text-similarity-davinci-001") -> List[float]:
-
-    # replace newlines, which can negatively affect performance.
-    text = text.replace("\n", " ")
+def get_embedding(text: str, engine="text-embedding-ada-002") -> List[float]:
 
     return openai.Embedding.create(input=[text], engine=engine)["data"][0]["embedding"]
 
 
 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embedding(
-    text: str, engine="text-similarity-davinci-001"
+    text: str, engine="text-embedding-ada-002"
 ) -> List[float]:
 
-    # replace newlines, which can negatively affect performance.
-    text = text.replace("\n", " ")
-
     return (await openai.Embedding.acreate(input=[text], engine=engine))["data"][0][
         "embedding"
     ]
@@ -38,12 +32,9 @@ async def aget_embedding(
 
 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embeddings(
-    list_of_text: List[str], engine="text-similarity-babbage-001"
+    list_of_text: List[str], engine="text-embedding-ada-002"
 ) -> List[List[float]]:
     assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
-
-    # replace newlines, which can negatively affect performance.
-    list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
     data = openai.Embedding.create(input=list_of_text, engine=engine).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
@@ -52,12 +43,9 @@ def get_embeddings(
 
 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embeddings(
-    list_of_text: List[str], engine="text-similarity-babbage-001"
+    list_of_text: List[str], engine="text-embedding-ada-002"
 ) -> List[List[float]]:
     assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
-
-    # replace newlines, which can negatively affect performance.
-    list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
     data = (await openai.Embedding.acreate(input=list_of_text, engine=engine)).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.