Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit c902a55

Browse files
committed
Validate search files (openai#69)
* Add validators for search files * Clean up fields
1 parent 205d063 commit c902a55

File tree

2 files changed

+182
-69
lines changed

2 files changed

+182
-69
lines changed

openai/cli.py

+69-43
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
import signal
44
import sys
55
import warnings
6+
from functools import partial
67

78
import openai
89
from openai.validators import (
910
apply_necessary_remediation,
10-
apply_optional_remediation,
11+
apply_validators,
12+
get_search_validators,
1113
get_validators,
1214
read_any_format,
1315
write_out_file,
16+
write_out_search_file,
1417
)
1518

1619

@@ -224,6 +227,41 @@ def list(cls, args):
224227

225228

226229
class Search:
230+
@classmethod
def prepare_data(cls, args):
    """
    Analyze a local data file for the search/classifications endpoints and
    run it through the search validators.

    Reads `args.file` (path of the file to analyze), `args.quiet` (accept
    every suggestion without prompting) and `args.purpose` (selects which
    fields are required).
    """
    sys.stdout.write("Analyzing...\n")
    path = args.file
    skip_prompts = args.quiet
    purpose = args.purpose

    optional_fields = ["metadata"]
    # Classification files need a `labels` field in addition to `text`.
    required_fields = (
        ["text", "labels"] if purpose == "classifications" else ["text"]
    )
    all_fields = required_fields + optional_fields

    df, remediation = read_any_format(path, fields=all_fields)

    # `metadata` is optional in the input; add an empty column when absent.
    if "metadata" not in df:
        df["metadata"] = None

    # Called with no dataframe and its result is ignored: only the
    # remediation produced while reading the file is processed here.
    apply_necessary_remediation(None, remediation)
    validators = get_search_validators(required_fields, optional_fields)

    # Pre-bind the search-specific arguments so the writer can be invoked
    # with the same four positional arguments apply_validators passes.
    writer = partial(write_out_search_file, purpose=purpose, fields=all_fields)

    apply_validators(df, path, remediation, validators, skip_prompts, writer)
264+
227265
@classmethod
228266
def create_alpha(cls, args):
229267
resp = openai.Search.create_alpha(
@@ -436,49 +474,14 @@ def prepare_data(cls, args):
436474

437475
validators = get_validators()
438476

439-
optional_remediations = []
440-
if remediation is not None:
441-
optional_remediations.append(remediation)
442-
for validator in validators:
443-
remediation = validator(df)
444-
if remediation is not None:
445-
optional_remediations.append(remediation)
446-
df = apply_necessary_remediation(df, remediation)
447-
448-
any_optional_or_necessary_remediations = any(
449-
[
450-
remediation
451-
for remediation in optional_remediations
452-
if remediation.optional_msg is not None
453-
or remediation.necessary_msg is not None
454-
]
477+
apply_validators(
478+
df,
479+
fname,
480+
remediation,
481+
validators,
482+
auto_accept,
483+
write_out_file_func=write_out_file,
455484
)
456-
any_necessary_applied = any(
457-
[
458-
remediation
459-
for remediation in optional_remediations
460-
if remediation.necessary_msg is not None
461-
]
462-
)
463-
any_optional_applied = False
464-
465-
if any_optional_or_necessary_remediations:
466-
sys.stdout.write(
467-
"\n\nBased on the analysis we will perform the following actions:\n"
468-
)
469-
for remediation in optional_remediations:
470-
df, optional_applied = apply_optional_remediation(
471-
df, remediation, auto_accept
472-
)
473-
any_optional_applied = any_optional_applied or optional_applied
474-
else:
475-
sys.stdout.write("\n\nNo remediations found.\n")
476-
477-
any_optional_or_necessary_applied = (
478-
any_optional_applied or any_necessary_applied
479-
)
480-
481-
write_out_file(df, fname, any_optional_or_necessary_applied, auto_accept)
482485

483486

484487
def tools_register(parser):
@@ -508,6 +511,29 @@ def help(args):
508511
)
509512
sub.set_defaults(func=FineTune.prepare_data)
510513

514+
sub = subparsers.add_parser("search.prepare_data")
515+
sub.add_argument(
516+
"-f",
517+
"--file",
518+
required=True,
519+
help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed."
520+
"This should be the local file path.",
521+
)
522+
sub.add_argument(
523+
"-p",
524+
"--purpose",
525+
help="Why are you uploading this file? (see https://beta.openai.com/docs/api-reference/ for purposes)",
526+
required=True,
527+
)
528+
sub.add_argument(
529+
"-q",
530+
"--quiet",
531+
required=False,
532+
action="store_true",
533+
help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
534+
)
535+
sub.set_defaults(func=Search.prepare_data)
536+
511537

512538
def api_register(parser):
513539
# Engine management

openai/validators.py

+113-26
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
22
import sys
3-
import pandas as pd
4-
import numpy as np
3+
from typing import Any, Callable, NamedTuple, Optional
54

6-
from typing import NamedTuple, Optional, Callable, Any
5+
import numpy as np
6+
import pandas as pd
77

88

99
class Remediation(NamedTuple):
@@ -70,7 +70,7 @@ def lower_case_column_creator(df):
7070
)
7171

7272

73-
def additional_column_validator(df):
73+
def additional_column_validator(df, fields=["prompt", "completion"]):
7474
"""
7575
This validator will remove additional columns from the dataframe.
7676
"""
@@ -79,9 +79,7 @@ def additional_column_validator(df):
7979
immediate_msg = None
8080
necessary_fn = None
8181
if len(df.columns) > 2:
82-
additional_columns = [
83-
c for c in df.columns if c not in ["prompt", "completion"]
84-
]
82+
additional_columns = [c for c in df.columns if c not in fields]
8583
warn_message = ""
8684
for ac in additional_columns:
8785
dups = [c for c in additional_columns if ac in c]
@@ -91,7 +89,7 @@ def additional_column_validator(df):
9189
necessary_msg = f"Remove additional columns/keys: {additional_columns}"
9290

9391
def necessary_fn(x):
94-
return x[["prompt", "completion"]]
92+
return x[fields]
9593

9694
return Remediation(
9795
name="additional_column",
@@ -101,50 +99,47 @@ def necessary_fn(x):
10199
)
102100

103101

104-
def non_empty_completion_validator(df):
102+
def non_empty_field_validator(df, field="completion"):
105103
"""
106104
This validator will ensure that no completion is empty.
107105
"""
108106
necessary_msg = None
109107
necessary_fn = None
110108
immediate_msg = None
111109

112-
if (
113-
df["completion"].apply(lambda x: x == "").any()
114-
or df["completion"].isnull().any()
115-
):
116-
empty_rows = (df["completion"] == "") | (df["completion"].isnull())
110+
if df[field].apply(lambda x: x == "").any() or df[field].isnull().any():
111+
empty_rows = (df[field] == "") | (df[field].isnull())
117112
empty_indexes = df.reset_index().index[empty_rows].tolist()
118-
immediate_msg = f"\n- `completion` column/key should not contain empty strings. These are rows: {empty_indexes}"
113+
immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}"
119114

120115
def necessary_fn(x):
121-
return x[x["completion"] != ""].dropna(subset=["completion"])
116+
return x[x[field] != ""].dropna(subset=[field])
122117

123-
necessary_msg = f"Remove {len(empty_indexes)} rows with empty completions"
118+
necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s"
124119
return Remediation(
125-
name="empty_completion",
120+
name=f"empty_{field}",
126121
immediate_msg=immediate_msg,
127122
necessary_msg=necessary_msg,
128123
necessary_fn=necessary_fn,
129124
)
130125

131126

132-
def duplicated_rows_validator(df):
127+
def duplicated_rows_validator(df, fields=["prompt", "completion"]):
133128
"""
134129
This validator will suggest to the user to remove duplicate rows if they exist.
135130
"""
136-
duplicated_rows = df.duplicated(subset=["prompt", "completion"])
131+
duplicated_rows = df.duplicated(subset=fields)
137132
duplicated_indexes = df.reset_index().index[duplicated_rows].tolist()
138133
immediate_msg = None
139134
optional_msg = None
140135
optional_fn = None
141136

142137
if len(duplicated_indexes) > 0:
143-
immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated prompt-completion pairs. These are rows: {duplicated_indexes}"
138+
immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}"
144139
optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows"
145140

146141
def optional_fn(x):
147-
return x.drop_duplicates(subset=["prompt", "completion"])
142+
return x.drop_duplicates(subset=fields)
148143

149144
return Remediation(
150145
name="duplicated_rows",
@@ -467,7 +462,7 @@ def lower_case(x):
467462
)
468463

469464

470-
def read_any_format(fname):
465+
def read_any_format(fname, fields=["prompt", "completion"]):
471466
"""
472467
This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
473468
- for .xlsx it will read the first sheet
@@ -502,7 +497,7 @@ def read_any_format(fname):
502497
content = f.read()
503498
df = pd.DataFrame(
504499
[["", line] for line in content.split("\n")],
505-
columns=["prompt", "completion"],
500+
columns=fields,
506501
dtype=str,
507502
)
508503
if fname.lower().endswith("jsonl") or fname.lower().endswith("json"):
@@ -623,7 +618,7 @@ def get_outfnames(fname, split):
623618
while True:
624619
index_suffix = f" ({i})" if i > 0 else ""
625620
candidate_fnames = [
626-
fname.split(".")[0] + "_prepared" + suffix + index_suffix + ".jsonl"
621+
os.path.splitext(fname)[0] + "_prepared" + suffix + index_suffix + ".jsonl"
627622
for suffix in suffixes
628623
]
629624
if not any(os.path.isfile(f) for f in candidate_fnames):
@@ -743,6 +738,30 @@ def write_out_file(df, fname, any_remediations, auto_accept):
743738
sys.stdout.write("Aborting... did not write the file\n")
744739

745740

741+
def write_out_search_file(df, fname, any_remediations, auto_accept, fields, purpose):
    """
    Write the `fields` columns of `df` to a new JSONL file if the user agrees.

    When `any_remediations` is falsy nothing is written and the user is told
    how to upload the original `fname`. Otherwise the user is asked to confirm
    (or `auto_accept` is honoured) before the output file is produced; the
    upload command for the new file is then printed. `purpose` is only used
    in the printed upload command.
    """
    # Nothing was changed: the original file can be uploaded directly.
    if not any_remediations:
        sys.stdout.write(
            f'\nYou can upload your file:\n> openai api files.create -f "{fname}" -p {purpose}'
        )
        return

    prompt = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "
    if not accept_suggestion(prompt, auto_accept):
        sys.stdout.write("Aborting... did not write the file\n")
        return

    fnames = get_outfnames(fname, split=False)

    # Without a split exactly one output file name is expected.
    assert len(fnames) == 1
    out_fname = fnames[0]
    df[fields].to_json(out_fname, lines=True, orient="records", force_ascii=False)

    sys.stdout.write(
        f'\nWrote modified file to {out_fname}`\nFeel free to take a look!\n\nNow upload that file:\n> openai api files.create -f "{out_fname}" -p {purpose}'
    )
763+
764+
746765
def infer_task_type(df):
747766
"""
748767
Infer the likely fine-tuning task type from the data
@@ -787,7 +806,7 @@ def get_validators():
787806
lambda x: necessary_column_validator(x, "prompt"),
788807
lambda x: necessary_column_validator(x, "completion"),
789808
additional_column_validator,
790-
non_empty_completion_validator,
809+
non_empty_field_validator,
791810
format_inferrer_validator,
792811
duplicated_rows_validator,
793812
long_examples_validator,
@@ -799,3 +818,71 @@ def get_validators():
799818
common_completion_suffix_validator,
800819
completions_space_start_validator,
801820
]
821+
822+
823+
def get_search_validators(required_fields, optional_fields):
    """
    Build the ordered list of validators used for search/classification files.

    Each returned validator is a callable taking the dataframe as its single
    argument. The list contains, in order:
      - one `necessary_column_validator` call per required field,
      - one `non_empty_field_validator` call per required field,
      - a `duplicated_rows_validator` call over the required fields,
      - an `additional_column_validator` call over required + optional fields.

    :param required_fields: list of field names that must be validated.
    :param optional_fields: list of field names that are allowed but not
        individually validated.
    :return: list of single-argument callables.
    """
    # `field` is bound as a default argument on purpose: a plain closure
    # created in a comprehension looks the variable up when it is *called*,
    # so every lambda would end up validating only the last required field.
    validators = [
        lambda x, field=field: necessary_column_validator(x, field)
        for field in required_fields
    ]
    validators += [
        lambda x, field=field: non_empty_field_validator(x, field)
        for field in required_fields
    ]
    validators += [lambda x: duplicated_rows_validator(x, required_fields)]
    validators += [
        lambda x: additional_column_validator(
            x, fields=required_fields + optional_fields
        ),
    ]

    return validators
838+
839+
840+
def apply_validators(
    df,
    fname,
    remediation,
    validators,
    auto_accept,
    write_out_file_func,
):
    """
    Run every validator over `df`, apply the resulting remediations and hand
    the final dataframe to `write_out_file_func`.

    :param df: dataframe to validate; it is replaced by the result of each
        remediation step.
    :param fname: name of the input file, forwarded to the writer.
    :param remediation: remediation obtained before validation (may be None);
        it is included with the ones produced by the validators.
    :param validators: callables taking the dataframe and returning a
        remediation or None.
    :param auto_accept: forwarded to `apply_optional_remediation` and to the
        writer.
    :param write_out_file_func: called as
        `write_out_file_func(df, fname, any_applied, auto_accept)`.
    """
    collected = [] if remediation is None else [remediation]

    # Necessary remediations are applied right away so that each following
    # validator sees the already-corrected dataframe.
    for validate in validators:
        outcome = validate(df)
        if outcome is None:
            continue
        collected.append(outcome)
        df = apply_necessary_remediation(df, outcome)

    anything_to_report = any(
        item
        for item in collected
        if item.optional_msg is not None or item.necessary_msg is not None
    )
    necessary_applied = any(
        item for item in collected if item.necessary_msg is not None
    )
    optional_applied = False

    if anything_to_report:
        sys.stdout.write(
            "\n\nBased on the analysis we will perform the following actions:\n"
        )
        for item in collected:
            df, applied_now = apply_optional_remediation(df, item, auto_accept)
            optional_applied = optional_applied or applied_now
    else:
        sys.stdout.write("\n\nNo remediations found.\n")

    write_out_file_func(
        df, fname, optional_applied or necessary_applied, auto_accept
    )

0 commit comments

Comments
 (0)