diff --git a/ml/digest-score/app.py b/ml/digest-score/app.py index c6af8e147..0f4111937 100644 --- a/ml/digest-score/app.py +++ b/ml/digest-score/app.py @@ -18,6 +18,24 @@ import dateutil.parser from google.cloud import storage from features.user_history import FEATURE_COLUMNS +import concurrent.futures +from threading import RLock +from collections import ChainMap +import copy + +class ThreadSafeUserFeatures: + def __init__(self): + self._data = {} + self._lock = RLock() + + def get(self): + with self._lock: + return dict(self._data) + + def update(self, new_features): + with self._lock: + self._data.update(new_features) + app = Flask(__name__) logging.basicConfig(level=logging.INFO, stream=sys.stdout) @@ -26,7 +44,8 @@ USER_HISTORY_PATH = 'user_features.pkl' MODEL_PIPELINE_PATH = 'predict_read_model-v003.pkl' pipeline = None -user_features = None +user_features_store = ThreadSafeUserFeatures() + # these buckets are used for reporting scores, we want to make sure # there is decent diversity in the returned scores. 
@@ -93,23 +112,23 @@ def predict_proba_wrapper(X): def refresh_data(): start = timer() global pipeline - global explainer - global user_features if os.getenv('LOAD_LOCAL_MODEL') == None: - print(f"loading data from {os.getenv('GCS_BUCKET')}") + app.logger.info(f"loading data from {os.getenv('GCS_BUCKET')}") gcs_bucket_name = os.getenv('GCS_BUCKET') download_from_gcs(gcs_bucket_name, f'data/features/{USER_HISTORY_PATH}', USER_HISTORY_PATH) download_from_gcs(gcs_bucket_name, f'data/models/{MODEL_PIPELINE_PATH}', MODEL_PIPELINE_PATH) pipeline = load_pipeline(MODEL_PIPELINE_PATH) - user_features = load_user_features(USER_HISTORY_PATH) - end = timer() - print('time to refresh data (in seconds):', end - start) - print('loaded pipeline:', pipeline) - print('loaded number of user_features:', len(user_features)) + + new_features = load_user_features(USER_HISTORY_PATH) + user_features_store.update(new_features) + + app.logger.info(f'time to refresh data (in seconds): {timer() - start}') + app.logger.info(f'loaded pipeline: {pipeline}') + app.logger.info(f'loaded number of user_features: {len(new_features)}') -def compute_score(user_id, item_features): - interaction_score = compute_interaction_score(user_id, item_features) +def compute_score(user_id, item_features, user_features): + interaction_score = compute_interaction_score(user_id, item_features, user_features) observe_score(interaction_score) return { 'score': interaction_score, @@ -117,7 +136,7 @@ def compute_score(user_id, item_features): } -def compute_interaction_score(user_id, item_features): +def compute_interaction_score(user_id, item_features, user_features): start = timer() original_url_host = urlparse(item_features.get('original_url')).netloc df_test = pd.DataFrame([{ @@ -160,15 +179,36 @@ def compute_interaction_score(user_id, item_features): df_test = df_test.fillna(0) df_predict = df_test[FEATURE_COLUMNS] - end = timer() - print('time to compute score (in seconds):', end - start) interaction_score = 
pipeline.predict_proba(df_predict) - print("INTERACTION SCORE: ", interaction_score) - print('item_features:\n', df_predict[df_predict != 0].stack()) + app.logger.info(f'INTERACTION SCORE: {interaction_score}') + app.logger.info(f'item_features:\n{df_predict[df_predict != 0].stack()}') + app.logger.info(f'time to compute score (in seconds): {timer() - start}') return np.float64(interaction_score[0][1]) +def process_parallel_item(user_id, key, item, user_features): + library_item_id = item['library_item_id'] + return library_item_id, compute_score(user_id, item, user_features) + +def parallel_compute_scores(user_id, items, max_workers=None): + user_features = user_features_store.get() + result = {} + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_item = {executor.submit(process_parallel_item, user_id, key, item, user_features): (key, item) + for key, item in items.items()} + + for future in concurrent.futures.as_completed(future_to_item): + key, item = future_to_item[future] + try: + library_item_id, score = future.result() + result[library_item_id] = score + except Exception as exc: + app.logger.error(f'Item {key} generated an exception: {exc}') + return result + + + @app.route('/_ah/health', methods=['GET']) def ready(): return jsonify({'OK': 'yes'}), 200 @@ -186,12 +226,13 @@ def refresh(): @app.route('/users//features', methods=['GET']) def get_user_features(user_id): - print("user_features", user_features) result = {} df_user = pd.DataFrame([{ 'user_id': user_id, }]) + user_features = user_features_store.get() + user_data = {} for name, df in user_features.items(): df = df[df['user_id'] == user_id] @@ -213,7 +254,8 @@ def predict(): if user_id is None: return jsonify({'error': 'Missing user_id'}), 400 - score = compute_score(user_id, item_features) + user_features = user_features_store.get() + score = compute_score(user_id, item_features, user_features) return jsonify({'score': score}) except Exception as e: 
app.logger.error(f"exception in predict endpoint: {request.get_json()}\n{e}") @@ -224,24 +266,18 @@ def predict(): def batch(): start = timer() try: - result = {} data = request.get_json() app.logger.info(f"batch scoring request: {data}") - user_id = data.get('user_id') items = data.get('items') + user_id = data.get('user_id') + if user_id is None: + return jsonify({'error': 'no user_id supplied'}), 400 + if len(items) > 101: + return jsonify({'error': f'too many items: {len(items)}'}), 400 + result = parallel_compute_scores(user_id, items) - if user_id is None: - return jsonify({'error': 'Missing user_id'}), 400 - - for key, item in items.items(): - print('key": ', key) - print('item: ', item) - library_item_id = item['library_item_id'] - result[library_item_id] = compute_score(user_id, item) - - end = timer() - print(f'time to compute batch of {len(items)} items (in seconds): {end - start}') + app.logger.info(f'time to compute batch of {len(items)} items (in seconds): {timer() - start}') return jsonify(result) except Exception as e: app.logger.error(f"exception in batch endpoint: {request.get_json()}\n{e}")