@@ -1,9 +1,9 @@
 import os
 import sys
-import pandas as pd
-import numpy as np
+from typing import Any, Callable, NamedTuple, Optional

-from typing import NamedTuple, Optional, Callable, Any
+import numpy as np
+import pandas as pd


 class Remediation(NamedTuple):
@@ -70,7 +70,7 @@ def lower_case_column_creator(df):
     )


-def additional_column_validator(df):
+def additional_column_validator(df, fields=["prompt", "completion"]):
     """
     This validator will remove additional columns from the dataframe.
     """
@@ -79,9 +79,7 @@ def additional_column_validator(df):
     immediate_msg = None
     necessary_fn = None
     if len(df.columns) > 2:
-        additional_columns = [
-            c for c in df.columns if c not in ["prompt", "completion"]
-        ]
+        additional_columns = [c for c in df.columns if c not in fields]
         warn_message = ""
         for ac in additional_columns:
             dups = [c for c in additional_columns if ac in c]
@@ -91,7 +89,7 @@ def additional_column_validator(df):
         necessary_msg = f"Remove additional columns/keys: {additional_columns}"

     def necessary_fn(x):
-        return x[["prompt", "completion"]]
+        return x[fields]

     return Remediation(
         name="additional_column",
@@ -101,50 +99,47 @@ def necessary_fn(x):
     )


-def non_empty_completion_validator(df):
+def non_empty_field_validator(df, field="completion"):
     """
     This validator will ensure that no completion is empty.
     """
     necessary_msg = None
     necessary_fn = None
     immediate_msg = None

-    if (
-        df["completion"].apply(lambda x: x == "").any()
-        or df["completion"].isnull().any()
-    ):
-        empty_rows = (df["completion"] == "") | (df["completion"].isnull())
+    if df[field].apply(lambda x: x == "").any() or df[field].isnull().any():
+        empty_rows = (df[field] == "") | (df[field].isnull())
         empty_indexes = df.reset_index().index[empty_rows].tolist()
-        immediate_msg = f"\n- `completion` column/key should not contain empty strings. These are rows: {empty_indexes}"
+        immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}"

         def necessary_fn(x):
-            return x[x["completion"] != ""].dropna(subset=["completion"])
+            return x[x[field] != ""].dropna(subset=[field])

-        necessary_msg = f"Remove {len(empty_indexes)} rows with empty completions"
+        necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s"
     return Remediation(
-        name="empty_completion",
+        name=f"empty_{field}",
         immediate_msg=immediate_msg,
         necessary_msg=necessary_msg,
         necessary_fn=necessary_fn,
     )


-def duplicated_rows_validator(df):
+def duplicated_rows_validator(df, fields=["prompt", "completion"]):
     """
     This validator will suggest to the user to remove duplicate rows if they exist.
     """
-    duplicated_rows = df.duplicated(subset=["prompt", "completion"])
+    duplicated_rows = df.duplicated(subset=fields)
     duplicated_indexes = df.reset_index().index[duplicated_rows].tolist()
     immediate_msg = None
     optional_msg = None
     optional_fn = None

     if len(duplicated_indexes) > 0:
-        immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated prompt-completion pairs. These are rows: {duplicated_indexes}"
+        immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}"
         optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows"

         def optional_fn(x):
-            return x.drop_duplicates(subset=["prompt", "completion"])
+            return x.drop_duplicates(subset=fields)

     return Remediation(
         name="duplicated_rows",
@@ -467,7 +462,7 @@ def lower_case(x):
     )


-def read_any_format(fname):
+def read_any_format(fname, fields=["prompt", "completion"]):
     """
     This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
     - for .xlsx it will read the first sheet
@@ -502,7 +497,7 @@ def read_any_format(fname):
                 content = f.read()
             df = pd.DataFrame(
                 [["", line] for line in content.split("\n")],
-                columns=["prompt", "completion"],
+                columns=fields,
                 dtype=str,
             )
         if fname.lower().endswith("jsonl") or fname.lower().endswith("json"):
@@ -623,7 +618,7 @@ def get_outfnames(fname, split):
     while True:
         index_suffix = f" ({i})" if i > 0 else ""
         candidate_fnames = [
-            fname.split(".")[0] + "_prepared" + suffix + index_suffix + ".jsonl"
+            os.path.splitext(fname)[0] + "_prepared" + suffix + index_suffix + ".jsonl"
             for suffix in suffixes
         ]
         if not any(os.path.isfile(f) for f in candidate_fnames):
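
The switch to `os.path.splitext` matters for any input path containing a dot before the extension: `str.split(".")[0]` cuts at the first dot in the whole string, so a relative path collapses to an empty prefix. For example:

    import os

    fname = "./data/train.jsonl"
    print(fname.split(".")[0])          # "" -- everything before the leading dot
    print(os.path.splitext(fname)[0])   # "./data/train"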
@@ -743,6 +738,30 @@ def write_out_file(df, fname, any_remediations, auto_accept):
         sys.stdout.write("Aborting... did not write the file\n")


+def write_out_search_file(df, fname, any_remediations, auto_accept, fields, purpose):
+    """
+    This function will write out a dataframe to a file, if the user would like to proceed.
+    """
+    input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "
+
+    if not any_remediations:
+        sys.stdout.write(
+            f'\nYou can upload your file:\n> openai api files.create -f "{fname}" -p {purpose}'
+        )
+
+    elif accept_suggestion(input_text, auto_accept):
+        fnames = get_outfnames(fname, split=False)
+
+        assert len(fnames) == 1
+        df[fields].to_json(fnames[0], lines=True, orient="records", force_ascii=False)
+
+        sys.stdout.write(
+            f'\nWrote modified file to `{fnames[0]}`\nFeel free to take a look!\n\nNow upload that file:\n> openai api files.create -f "{fnames[0]}" -p {purpose}'
+        )
+    else:
+        sys.stdout.write("Aborting... did not write the file\n")
+
+
 def infer_task_type(df):
     """
     Infer the likely fine-tuning task type from the data
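
A hedged sketch of how the new writer would be called from a data-preparation flow; the `fields` and `purpose` values here are illustrative, not taken from this diff:

    # Suggest the upload command, or write "data_prepared.jsonl" after confirmation.
    write_out_search_file(
        df,
        fname="data.csv",
        any_remediations=True,
        auto_accept=False,
        fields=["text", "metadata"],
        purpose="search",
    )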
@@ -787,7 +806,7 @@ def get_validators():
         lambda x: necessary_column_validator(x, "prompt"),
         lambda x: necessary_column_validator(x, "completion"),
         additional_column_validator,
-        non_empty_completion_validator,
+        non_empty_field_validator,
         format_inferrer_validator,
         duplicated_rows_validator,
         long_examples_validator,
@@ -799,3 +818,71 @@ def get_validators():
         common_completion_suffix_validator,
         completions_space_start_validator,
     ]
+
+
+def get_search_validators(required_fields, optional_fields):
+    validators = [
+        lambda x, field=field: necessary_column_validator(x, field) for field in required_fields
+    ]
+    validators += [
+        lambda x, field=field: non_empty_field_validator(x, field) for field in required_fields
+    ]
+    validators += [lambda x: duplicated_rows_validator(x, required_fields)]
+    validators += [
+        lambda x: additional_column_validator(
+            x, fields=required_fields + optional_fields
+        ),
+    ]
+
+    return validators
+
+
+def apply_validators(
+    df,
+    fname,
+    remediation,
+    validators,
+    auto_accept,
+    write_out_file_func,
+):
+    optional_remediations = []
+    if remediation is not None:
+        optional_remediations.append(remediation)
+    for validator in validators:
+        remediation = validator(df)
+        if remediation is not None:
+            optional_remediations.append(remediation)
+            df = apply_necessary_remediation(df, remediation)
+
+    any_optional_or_necessary_remediations = any(
+        [
+            remediation
+            for remediation in optional_remediations
+            if remediation.optional_msg is not None
+            or remediation.necessary_msg is not None
+        ]
+    )
+    any_necessary_applied = any(
+        [
+            remediation
+            for remediation in optional_remediations
+            if remediation.necessary_msg is not None
+        ]
+    )
+    any_optional_applied = False
+
+    if any_optional_or_necessary_remediations:
+        sys.stdout.write(
+            "\n\nBased on the analysis we will perform the following actions:\n"
+        )
+        for remediation in optional_remediations:
+            df, optional_applied = apply_optional_remediation(
+                df, remediation, auto_accept
+            )
+            any_optional_applied = any_optional_applied or optional_applied
+    else:
+        sys.stdout.write("\n\nNo remediations found.\n")
+
+    any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied
+
+    write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)
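
Taken together, the two new helpers let a search-file preparation flow reuse the fine-tuning validator machinery. A usage sketch, assuming this module is importable; the field names and the partial application of `write_out_search_file` are illustrative, not prescribed by this diff:

    from functools import partial

    validators = get_search_validators(
        required_fields=["text"], optional_fields=["metadata"]
    )
    apply_validators(
        df,
        fname="data.csv",
        remediation=None,
        validators=validators,
        auto_accept=False,
        write_out_file_func=partial(
            write_out_search_file, fields=["text", "metadata"], purpose="search"
        ),
    )

Note the `field=field` default arguments in `get_search_validators`: lambdas created inside a comprehension close over the loop variable itself, so without the eager binding every generated validator would check only the last entry of `required_fields`.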