From ef81fd039071e0928b7c9adbacd4e754ca0eaffa Mon Sep 17 00:00:00 2001
From: Ashley Xu
Date: Tue, 28 Nov 2023 23:32:18 +0000
Subject: [PATCH] fix: update the llm+kmeans notebook with recent change

---
 .../bq_dataframes_llm_kmeans.ipynb            | 47 +++++--------------
 1 file changed, 12 insertions(+), 35 deletions(-)

diff --git a/notebooks/generative_ai/bq_dataframes_llm_kmeans.ipynb b/notebooks/generative_ai/bq_dataframes_llm_kmeans.ipynb
index 8d75950925..5f74046fc0 100644
--- a/notebooks/generative_ai/bq_dataframes_llm_kmeans.ipynb
+++ b/notebooks/generative_ai/bq_dataframes_llm_kmeans.ipynb
@@ -366,18 +366,6 @@
     "predicted_embeddings.head() "
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "4H_etYfsEOFP"
-   },
-   "outputs": [],
-   "source": [
-    "# Join the complaints with their embeddings in the same DataFrame\n",
-    "combined_df = downsampled_issues_df.join(predicted_embeddings)"
-   ]
-  },
   {
    "attachments": {},
    "cell_type": "markdown",
@@ -426,30 +414,19 @@
    "outputs": [],
    "source": [
     "# Use KMeans clustering to calculate our groups. Will take ~3 minutes.\n",
-    "cluster_model.fit(combined_df[[\"text_embedding\"]])\n",
-    "clustered_result = cluster_model.predict(combined_df[[\"text_embedding\"]])\n",
+    "cluster_model.fit(predicted_embeddings[[\"text_embedding\"]])\n",
+    "clustered_result = cluster_model.predict(predicted_embeddings)\n",
     "# Notice the CENTROID_ID column, which is the ID number of the group that\n",
     "# each complaint belongs to.\n",
     "clustered_result.head(n=5)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Join the group number to the complaints and their text embeddings\n",
-    "combined_clustered_result = combined_df.join(clustered_result)\n",
-    "combined_clustered_result.head(n=5) "
-   ]
-  },
   {
    "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Our dataframe combined_clustered_result now has three columns: the complaints, their text embeddings, and an ID from 1-10 (inclusive) indicating which semantically similar group they belong to."
+    "Our dataframe clustered_result now has three columns: the complaint content, their text embeddings, and an ID from 1-10 (inclusive) indicating which semantically similar group they belong to."
    ]
   },
   {
@@ -480,14 +457,14 @@
    "source": [
     "# Using bigframes, with syntax identical to pandas,\n",
     "# filter out the first and second groups\n",
-    "cluster_1_result = combined_clustered_result[\n",
-    "    combined_clustered_result[\"CENTROID_ID\"] == 1\n",
-    "][[\"consumer_complaint_narrative\"]]\n",
+    "cluster_1_result = clustered_result[\n",
+    "    clustered_result[\"CENTROID_ID\"] == 1\n",
+    "][[\"content\"]]\n",
     "cluster_1_result_pandas = cluster_1_result.head(5).to_pandas()\n",
     "\n",
-    "cluster_2_result = combined_clustered_result[\n",
-    "    combined_clustered_result[\"CENTROID_ID\"] == 2\n",
-    "][[\"consumer_complaint_narrative\"]]\n",
+    "cluster_2_result = clustered_result[\n",
+    "    clustered_result[\"CENTROID_ID\"] == 2\n",
+    "][[\"content\"]]\n",
     "cluster_2_result_pandas = cluster_2_result.head(5).to_pandas()"
    ]
   },
@@ -503,15 +480,15 @@
     "prompt1 = 'comment list 1:\\n'\n",
     "for i in range(5):\n",
     "    prompt1 += str(i + 1) + '. ' + \\\n",
-    "        cluster_1_result_pandas[\"consumer_complaint_narrative\"].iloc[i] + '\\n'\n",
+    "        cluster_1_result_pandas[\"content\"].iloc[i] + '\\n'\n",
     "\n",
     "prompt2 = 'comment list 2:\\n'\n",
     "for i in range(5):\n",
     "    prompt2 += str(i + 1) + '. ' + \\\n",
-    "        cluster_2_result_pandas[\"consumer_complaint_narrative\"].iloc[i] + '\\n'\n",
+    "        cluster_2_result_pandas[\"content\"].iloc[i] + '\\n'\n",
     "\n",
     "print(prompt1)\n",
-    "print(prompt2)\n"
+    "print(prompt2)"
    ]
   },
   {