diff --git a/dingo/exec/spark.py b/dingo/exec/spark.py
index 653f637..5af13f8 100644
--- a/dingo/exec/spark.py
+++ b/dingo/exec/spark.py
@@ -28,85 +28,87 @@ def __init__(self,
                  input_args: InputArgs,
                  spark_rdd: RDD = None,
                  spark_session: SparkSession = None,
                  spark_conf: SparkConf = None):
-        # eval param
+        # Evaluation parameters
         self.llm: Optional[BaseLLM] = None
         self.group: Optional[Dict] = None
         self.summary: Optional[SummaryModel] = None
         self.bad_info_list: Optional[RDD] = None
         self.good_info_list: Optional[RDD] = None
-        # init param
+        # Initialization parameters
         self.input_args = input_args
         self.spark_rdd = spark_rdd
         self.spark_session = spark_session
         self.spark_conf = spark_conf
+        self._sc = None  # SparkContext placeholder

     def __getstate__(self):
+        """Custom serialization to exclude non-serializable Spark objects."""
         state = self.__dict__.copy()
         del state['spark_session']
         del state['spark_rdd']
+        del state['_sc']
         return state

     def __setstate__(self, state):
         self.__dict__.update(state)

-    # def load_data(self) -> Generator[Any, None, None]:
-    #     """
-    #     Reads data from given path. Returns generator of raw data.
-    #
-    #     **Run in executor.**
-    #
-    #     Returns:
-    #         Generator[Any, None, None]: Generator of raw data.
-    #     """
-    #     datasource_cls = datasource_map[self.input_args.datasource]
-    #     dataset_cls = dataset_map["spark"]
-    #
-    #     datasource: DataSource = datasource_cls(input_args=self.input_args)
-    #     dataset: Dataset = dataset_cls(source=datasource)
-    #     return dataset.get_data()
+    def _initialize_spark(self):
+        """Initialize Spark session if not already provided."""
+        if self.spark_session is not None:
+            return self.spark_session, self.spark_session.sparkContext
+        elif self.spark_conf is not None:
+            spark = SparkSession.builder.config(conf=self.spark_conf).getOrCreate()
+            return spark, spark.sparkContext
+        else:
+            raise ValueError('Both spark_session and spark_conf are None. Please provide one.')

     def load_data(self) -> RDD:
+        """Load and return the RDD data."""
         return self.spark_rdd

     def execute(self) -> List[SummaryModel]:
+        """Main execution method for Spark evaluation."""
         create_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+        # Initialize models and configuration
         Model.apply_config(self.input_args.custom_config, self.input_args.eval_group)
         self.group = Model.get_group(self.input_args.eval_group)
+
         if GlobalConfig.config and GlobalConfig.config.llm_config:
             for llm_name in GlobalConfig.config.llm_config:
                 self.llm = Model.get_llm(llm_name)
-        print("============= Init pyspark =============")
-        if self.spark_session is not None:
-            spark = self.spark_session
-            sc = spark.sparkContext
-        elif self.spark_conf is not None:
-            spark = SparkSession.builder.config(conf=self.spark_conf).getOrCreate()
-            sc = spark.sparkContext
-        else:
-            raise ValueError('[spark_session] and [spark_conf] is none. Please input.')
+
+        print("============= Init PySpark =============")
+        spark, sc = self._initialize_spark()
+        self._sc = sc
         print("============== Init Done ===============")

         try:
-            # Exec Eval
-            # if self.spark_rdd is not None:
-            #     data_rdd = self.spark_rdd
-            # else:
-            #     data_rdd = sc.parallelize(self.load_data(), 3)
+            # Load and process data
             data_rdd = self.load_data()
             total = data_rdd.count()

-            data_info_list = data_rdd.map(self.evaluate)
-            bad_info_list = data_info_list.filter(lambda x: True if x['error_status'] else False)
-            bad_info_list.cache()
-            self.bad_info_list = bad_info_list
+            # Apply configuration for Spark driver
+            Model.apply_config_for_spark_driver(self.input_args.custom_config, self.input_args.eval_group)
+
+            # Broadcast necessary objects to workers
+            broadcast_group = sc.broadcast(self.group)
+            broadcast_llm = sc.broadcast(self.llm) if self.llm else None
+
+            # Evaluate data
+            data_info_list = data_rdd.map(
+                lambda x: self._evaluate_item(x, broadcast_group, broadcast_llm)
+            ).persist()  # Cache the evaluated data for multiple uses
+
+            # Filter and count bad/good items
+            self.bad_info_list = data_info_list.filter(lambda x: x['error_status'])
+            num_bad = self.bad_info_list.count()
+
             if self.input_args.save_correct:
-                good_info_list = data_info_list.filter(lambda x: True if not x['error_status'] else False)
-                good_info_list.cache()
-                self.good_info_list = good_info_list
+                self.good_info_list = data_info_list.filter(lambda x: not x['error_status'])

-            num_bad = bad_info_list.count()
-            # calculate count
+            # Create summary
             self.summary = SummaryModel(
                 task_id=str(uuid.uuid1()),
                 task_name=self.input_args.task_name,
@@ -114,154 +116,172 @@ def execute(self) -> List[SummaryModel]:
                 input_path=self.input_args.input_path if not self.spark_rdd else '',
                 output_path='',
                 create_time=create_time,
-                score=0,
-                num_good=0,
-                num_bad=0,
-                total=0,
+                score=round((total - num_bad) / total * 100, 2) if total > 0 else 0,
+                num_good=total - num_bad,
+                num_bad=num_bad,
+                total=total,
                 type_ratio={},
                 name_ratio={}
             )
-            self.summary.total = total
-            self.summary.num_bad = num_bad
-            self.summary.num_good = total - num_bad
-            self.summary.score = round(self.summary.num_good / self.summary.total * 100, 2)
+            # Generate detailed summary
+            self._summarize_results()

-            self.summarize()
             self.summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+            return [self.summary]
+
         except Exception as e:
             raise e
         finally:
             if not self.input_args.save_data:
-                self.clean_context_and_session()
+                self._cleanup(spark)
             else:
                 self.spark_session = spark
-        return [self.summary]

-    def evaluate(self, data_rdd_item) -> Dict[str, Any]:
-        Model.apply_config_for_spark_driver(self.input_args.custom_config, self.input_args.eval_group)
-        # eval with models ( Big Data Caution )
+    def _evaluate_item(self, data_rdd_item, broadcast_group, broadcast_llm) -> Dict[str, Any]:
+        """Evaluate a single data item using broadcast variables."""
         data: MetaData = data_rdd_item
         result_info = ResultInfo(data_id=data.data_id, prompt=data.prompt, content=data.content)
+
         if self.input_args.save_raw:
             result_info.raw_data = data.raw_data
+
+        group = broadcast_group.value
+        llm = broadcast_llm.value if broadcast_llm else None
+
         bad_type_list = []
         good_type_list = []
         bad_name_list = []
         good_name_list = []
         bad_reason_list = []
         good_reason_list = []
-        for group_type, group in self.group.items():
+
+        for group_type, group_items in group.items():
             if group_type == 'rule':
-                r_i = self.evaluate_rule(group, data)
+                r_i = self._evaluate_rule(group_items, data)
             elif group_type == 'prompt':
-                r_i = self.evaluate_prompt(group, data)
+                r_i = self._evaluate_prompt(group_items, data, llm)
             else:
                 raise RuntimeError(f'Unsupported group type: {group_type}')
+
             if r_i.error_status:
                 result_info.error_status = True
-                bad_type_list = bad_type_list + r_i.type_list
-                bad_name_list = bad_name_list + r_i.name_list
-                bad_reason_list = bad_reason_list + r_i.reason_list
+                bad_type_list.extend(r_i.type_list)
+                bad_name_list.extend(r_i.name_list)
+                bad_reason_list.extend(r_i.reason_list)
             else:
-                good_type_list = good_type_list + r_i.type_list
-                good_name_list = good_name_list + r_i.name_list
-                good_reason_list = good_reason_list + r_i.reason_list
-        if result_info.error_status:
-            result_info.type_list = list(set(bad_type_list))
-            for name in bad_name_list:
-                if name not in result_info.name_list:
-                    result_info.name_list.append(name)
-            for reason in bad_reason_list:
-                if reason and reason not in result_info.reason_list:
-                    result_info.reason_list.append(reason)
-        else:
-            result_info.type_list = list(set(good_type_list))
-            for name in good_name_list:
-                if name not in result_info.name_list:
-                    result_info.name_list.append(name)
-            for reason in good_reason_list:
-                if reason and reason not in result_info.reason_list:
-                    result_info.reason_list.append(reason)
+                good_type_list.extend(r_i.type_list)
+                good_name_list.extend(r_i.name_list)
+                good_reason_list.extend(r_i.reason_list)
+
+        # Process results
+        target_list = bad_type_list if result_info.error_status else good_type_list
+        result_info.type_list = list(set(target_list))
+
+        target_names = bad_name_list if result_info.error_status else good_name_list
+        result_info.name_list = list(dict.fromkeys(target_names))  # Preserve order while removing duplicates
+
+        target_reasons = bad_reason_list if result_info.error_status else good_reason_list
+        result_info.reason_list = [r for r in target_reasons if r]  # Filter out None/empty reasons

         return result_info.to_dict()

-    def evaluate_rule(self, group: List[BaseRule], d: MetaData) -> ResultInfo:
-        result_info = ResultInfo(data_id=d.data_id, prompt=d.prompt, content=d.content)
-        log.debug("[RuleGroup]: " + str(group))
+    def _evaluate_rule(self, group: List[BaseRule], data: MetaData) -> ResultInfo:
+        """Evaluate data against a group of rules."""
+        result_info = ResultInfo(data_id=data.data_id, prompt=data.prompt, content=data.content)
+
         bad_type_list = []
         good_type_list = []
         bad_name_list = []
         good_name_list = []
         bad_reason_list = []
         good_reason_list = []
-        for r in group:
-            # execute rule
-            tmp: ModelRes = r.eval(d)
-            # analyze result
-            if tmp.error_status:
+
+        for rule in group:
+            res: ModelRes = rule.eval(data)
+
+            if res.error_status:
                 result_info.error_status = True
-                bad_type_list.append(tmp.type)
-                bad_name_list.append(tmp.type + '-' + tmp.name)
-                bad_reason_list.extend(tmp.reason)
+                bad_type_list.append(res.type)
+                bad_name_list.append(f"{res.type}-{res.name}")
+                bad_reason_list.extend(res.reason)
             else:
-                good_type_list.append(tmp.type)
-                good_name_list.append(tmp.type + '-' + tmp.name)
-                good_reason_list.extend(tmp.reason)
-        if result_info.error_status:
-            result_info.type_list = list(set(bad_type_list))
-            result_info.name_list = bad_name_list
-            result_info.reason_list = bad_reason_list
-        else:
-            result_info.type_list = list(set(good_type_list))
-            result_info.name_list = good_name_list
-            result_info.reason_list = good_reason_list
+                good_type_list.append(res.type)
+                good_name_list.append(f"{res.type}-{res.name}")
+                good_reason_list.extend(res.reason)
+
+        # Set results
+        target_list = bad_type_list if result_info.error_status else good_type_list
+        result_info.type_list = list(set(target_list))
+        result_info.name_list = bad_name_list if result_info.error_status else good_name_list
+        result_info.reason_list = bad_reason_list if result_info.error_status else good_reason_list
+
         return result_info

-    def evaluate_prompt(self, group: List[BasePrompt], d: MetaData) -> ResultInfo:
-        result_info = ResultInfo(data_id=d.data_id, prompt=d.prompt, content=d.content)
-        log.debug("[PromptGroup]: " + str(group))
+    def _evaluate_prompt(self, group: List[BasePrompt], data: MetaData, llm: BaseLLM) -> ResultInfo:
+        """Evaluate data against a group of prompts using LLM."""
+        if llm is None:
+            raise ValueError("LLM is required for prompt evaluation")
+
+        result_info = ResultInfo(data_id=data.data_id, prompt=data.prompt, content=data.content)
+
         bad_type_list = []
         good_type_list = []
         bad_name_list = []
         good_name_list = []
         bad_reason_list = []
         good_reason_list = []
-        for p in group:
-            self.llm.set_prompt(p)
-            # execute prompt
-            tmp: ModelRes = self.llm.call_api(d)
-            # analyze result
-            if tmp.error_status:
+
+        for prompt in group:
+            llm.set_prompt(prompt)
+            res: ModelRes = llm.call_api(data)
+
+            if res.error_status:
                 result_info.error_status = True
-                bad_type_list.append(tmp.type)
-                bad_name_list.append(tmp.type + '-' + tmp.name)
-                bad_reason_list.extend(tmp.reason)
+                bad_type_list.append(res.type)
+                bad_name_list.append(f"{res.type}-{res.name}")
+                bad_reason_list.extend(res.reason)
             else:
-                good_type_list.append(tmp.type)
-                good_name_list.append(tmp.type + '-' + tmp.name)
-                good_reason_list.extend(tmp.reason)
-        if result_info.error_status:
-            result_info.type_list = list(set(bad_type_list))
-            result_info.name_list = bad_name_list
-            result_info.reason_list = bad_reason_list
-        else:
-            result_info.type_list = list(set(good_type_list))
-            result_info.name_list = good_name_list
-            result_info.reason_list = good_reason_list
+                good_type_list.append(res.type)
+                good_name_list.append(f"{res.type}-{res.name}")
+                good_reason_list.extend(res.reason)
+
+        # Set results
+        target_list = bad_type_list if result_info.error_status else good_type_list
+        result_info.type_list = list(set(target_list))
+        result_info.name_list = bad_name_list if result_info.error_status else good_name_list
+        result_info.reason_list = bad_reason_list if result_info.error_status else good_reason_list
+
         return result_info

-    def summarize(self):
-        list_rdd = self.bad_info_list.flatMap(lambda row: row['type_list'])
-        unique_list = list_rdd.distinct().collect()
-        for metric_type in unique_list:
-            num = self.bad_info_list.filter(lambda x: metric_type in x['type_list']).count()
-            self.summary.type_ratio[metric_type] = round(num / self.summary.total, 6)
+    def _summarize_results(self):
+        """Generate summary statistics from bad info list."""
+        if not self.bad_info_list:
+            return
+
+        # Calculate type ratios
+        type_counts = (
+            self.bad_info_list
+            .flatMap(lambda x: [(t, 1) for t in x['type_list']])
+            .reduceByKey(lambda a, b: a + b)
+            .collectAsMap()
+        )
+        self.summary.type_ratio = {
+            k: round(v / self.summary.total, 6)
+            for k, v in type_counts.items()
+        }

-        list_rdd = self.bad_info_list.flatMap(lambda row: row['name_list'])
-        unique_list = list_rdd.distinct().collect()
-        for name in unique_list:
-            num = self.bad_info_list.filter(lambda x: name in x['name_list']).count()
-            self.summary.name_ratio[name] = round(num / self.summary.total, 6)
+        # Calculate name ratios
+        name_counts = (
+            self.bad_info_list
+            .flatMap(lambda x: [(n, 1) for n in x['name_list']])
+            .reduceByKey(lambda a, b: a + b)
+            .collectAsMap()
+        )
+        self.summary.name_ratio = {
+            k: round(v / self.summary.total, 6)
+            for k, v in name_counts.items()
+        }

         self.summary.type_ratio = dict(sorted(self.summary.type_ratio.items()))
         self.summary.name_ratio = dict(sorted(self.summary.name_ratio.items()))
@@ -270,17 +290,40 @@ def get_summary(self):
         return self.summary

     def get_bad_info_list(self):
+        if self.input_args.save_raw:
+            return self.bad_info_list.map(lambda x: {
+                **x['raw_data'],
+                'dingo_result': {
+                    'error_status': x['error_status'],
+                    'type_list': x['type_list'],
+                    'name_list': x['name_list'],
+                    'reason_list': x['reason_list']
+                }
+            })
         return self.bad_info_list

     def get_good_info_list(self):
+        if self.input_args.save_raw:
+            return self.good_info_list.map(lambda x: {
+                **x['raw_data'],
+                'dingo_result': {
+                    'error_status': x['error_status'],
+                    'type_list': x['type_list'],
+                    'name_list': x['name_list'],
+                    'reason_list': x['reason_list']
+                }
+            })
         return self.good_info_list

     def save_data(self, start_time):
+        """Save output data to specified path."""
         output_path = os.path.join(self.input_args.output_path, start_time)
         model_path = os.path.join(output_path, self.input_args.eval_group)
-        if not os.path.exists(model_path):
-            os.makedirs(model_path)
+        os.makedirs(model_path, exist_ok=True)

-    def clean_context_and_session(self):
-        self.spark_session.stop()
-        self.spark_session.sparkContext.stop()
+    def _cleanup(self, spark):
+        """Clean up Spark resources."""
+        if spark:
+            spark.stop()
+            if spark.sparkContext:
+                spark.sparkContext.stop()
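
Reviewer note: below is a minimal usage sketch of the refactored executor, included only for context. It assumes the class defined in dingo/exec/spark.py is exported as SparkExecutor, that InputArgs and MetaData accept the keyword fields referenced in this diff (task_name, eval_group, input_path, output_path, save_data, save_correct, save_raw; data_id, prompt, content, raw_data), and that the import paths below are correct; none of these names are confirmed by the patch itself.

    # Hypothetical driver script; class name, import paths, and constructor
    # signatures are assumptions, not shown in this patch.
    from pyspark.sql import SparkSession

    from dingo.exec.spark import SparkExecutor   # assumed export name
    from dingo.io import InputArgs, MetaData     # assumed module path

    spark = SparkSession.builder.appName("dingo-eval").getOrCreate()

    # Build an RDD of MetaData items; raw_data matters only when save_raw=True.
    data_rdd = spark.sparkContext.parallelize([
        MetaData(data_id="1", prompt="", content="hello world", raw_data={"content": "hello world"}),
        MetaData(data_id="2", prompt="", content="", raw_data={"content": ""}),
    ])

    input_args = InputArgs(
        task_name="demo",
        eval_group="default",   # group resolved by Model.get_group()
        input_path="",          # ignored when spark_rdd is supplied
        output_path="outputs/",
        save_data=True,         # keep the SparkSession alive after execute()
        save_correct=True,      # also populate good_info_list
        save_raw=True,          # attach raw_data and enable dingo_result wrapping
    )

    executor = SparkExecutor(input_args=input_args, spark_rdd=data_rdd, spark_session=spark)
    summary = executor.execute()[0]              # execute() returns [SummaryModel]
    bad_records = executor.get_bad_info_list().collect()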