Add a RWLock to the user features storage for refreshing

This commit is contained in:
Jackson Harper
2024-07-05 12:41:08 +08:00
parent 65f5ff88ea
commit 39df17fce0

View File

@ -18,6 +18,24 @@ import dateutil.parser
from google.cloud import storage
from features.user_history import FEATURE_COLUMNS
import concurrent.futures
from threading import Lock, RLock
from collections import ChainMap
import copy
class ThreadSafeUserFeatures:
    """Mutable user-feature mapping guarded by a reentrant lock.

    Readers get a shallow snapshot of the mapping, so a background
    refresh can merge in new features without racing in-flight
    request handlers.
    """

    def __init__(self):
        # Feature mapping; every access goes through _lock.
        self._data = {}
        self._lock = RLock()

    def get(self):
        """Return a shallow copy of the current feature mapping."""
        with self._lock:
            snapshot = dict(self._data)
        return snapshot

    def update(self, new_features):
        """Merge *new_features* into the stored mapping atomically."""
        with self._lock:
            self._data.update(new_features)
app = Flask(__name__)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
@ -26,7 +44,8 @@ USER_HISTORY_PATH = 'user_features.pkl'
MODEL_PIPELINE_PATH = 'predict_read_model-v003.pkl'
pipeline = None
user_features = None
user_features_store = ThreadSafeUserFeatures()
# these buckets are used for reporting scores, we want to make sure
# there is decent diversity in the returned scores.
@ -93,23 +112,23 @@ def predict_proba_wrapper(X):
def refresh_data():
start = timer()
global pipeline
global explainer
global user_features
if os.getenv('LOAD_LOCAL_MODEL') == None:
print(f"loading data from {os.getenv('GCS_BUCKET')}")
app.logger.info(f"loading data from {os.getenv('GCS_BUCKET')}")
gcs_bucket_name = os.getenv('GCS_BUCKET')
download_from_gcs(gcs_bucket_name, f'data/features/{USER_HISTORY_PATH}', USER_HISTORY_PATH)
download_from_gcs(gcs_bucket_name, f'data/models/{MODEL_PIPELINE_PATH}', MODEL_PIPELINE_PATH)
pipeline = load_pipeline(MODEL_PIPELINE_PATH)
user_features = load_user_features(USER_HISTORY_PATH)
end = timer()
print('time to refresh data (in seconds):', end - start)
print('loaded pipeline:', pipeline)
print('loaded number of user_features:', len(user_features))
new_features = load_user_features(USER_HISTORY_PATH)
user_features_store.update(new_features)
app.logger.info(f'time to refresh data (in seconds): {timer() - start}')
app.logger.info(f'loaded pipeline: {pipeline}')
app.logger.info(f'loaded number of user_features: {len(new_features)}')
def compute_score(user_id, item_features):
interaction_score = compute_interaction_score(user_id, item_features)
def compute_score(user_id, item_features, user_features):
interaction_score = compute_interaction_score(user_id, item_features, user_features)
observe_score(interaction_score)
return {
'score': interaction_score,
@ -117,7 +136,7 @@ def compute_score(user_id, item_features):
}
def compute_interaction_score(user_id, item_features):
def compute_interaction_score(user_id, item_features, user_features):
start = timer()
original_url_host = urlparse(item_features.get('original_url')).netloc
df_test = pd.DataFrame([{
@ -160,15 +179,36 @@ def compute_interaction_score(user_id, item_features):
df_test = df_test.fillna(0)
df_predict = df_test[FEATURE_COLUMNS]
end = timer()
print('time to compute score (in seconds):', end - start)
interaction_score = pipeline.predict_proba(df_predict)
print("INTERACTION SCORE: ", interaction_score)
print('item_features:\n', df_predict[df_predict != 0].stack())
app.logger.info(f'INTERACTION SCORE: {interaction_score}')
app.logger.info(f'item_features:\n{df_predict[df_predict != 0].stack()}')
app.logger.info(f'time to compute score (in seconds): {timer() - start}')
return np.float64(interaction_score[0][1])
def process_parallel_item(user_id, key, item, user_features):
    """Score a single item for *user_id*.

    Worker function for the thread pool; *key* is accepted for
    interface parity with the submitting loop but is not used here.
    Returns a ``(library_item_id, score)`` pair.
    """
    return item['library_item_id'], compute_score(user_id, item, user_features)
def parallel_compute_scores(user_id, items, max_workers=None):
    """Score all *items* for *user_id* concurrently.

    Takes one consistent snapshot of the user-feature store up front so
    every worker sees the same features even if a refresh lands mid-batch.
    Items whose scoring raises are logged and omitted from the result.
    Returns a mapping of ``library_item_id`` -> score.
    """
    snapshot = user_features_store.get()
    scores = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {
            pool.submit(process_parallel_item, user_id, item_key, item, snapshot): (item_key, item)
            for item_key, item in items.items()
        }
        for done in concurrent.futures.as_completed(pending):
            item_key, _ = pending[done]
            try:
                library_item_id, score = done.result()
            except Exception as exc:
                app.logger.error(f'Item {item_key} generated an exception: {exc}')
            else:
                scores[library_item_id] = score
    return scores
@app.route('/_ah/health', methods=['GET'])
def ready():
    """Health-check endpoint for the platform; always reports OK."""
    payload = {'OK': 'yes'}
    return jsonify(payload), 200
@ -186,12 +226,13 @@ def refresh():
@app.route('/users/<user_id>/features', methods=['GET'])
def get_user_features(user_id):
print("user_features", user_features)
result = {}
df_user = pd.DataFrame([{
'user_id': user_id,
}])
user_features = user_features_store.get()
user_data = {}
for name, df in user_features.items():
df = df[df['user_id'] == user_id]
@ -213,7 +254,8 @@ def predict():
if user_id is None:
return jsonify({'error': 'Missing user_id'}), 400
score = compute_score(user_id, item_features)
user_features = user_features_store.get()
score = compute_score(user_id, item_features, user_features)
return jsonify({'score': score})
except Exception as e:
app.logger.error(f"exception in predict endpoint: {request.get_json()}\n{e}")
@ -224,24 +266,18 @@ def predict():
def batch():
start = timer()
try:
result = {}
data = request.get_json()
app.logger.info(f"batch scoring request: {data}")
user_id = data.get('user_id')
items = data.get('items')
user_id = data.get('user_id')
if user_id == None:
return jsonify({'error': 'no user_id supplied'}), 400
if len(items) > 101:
return jsonify({'error': f'too many items: {len(items)}'}), 400
result = parallel_compute_scores(user_id, items)
if user_id is None:
return jsonify({'error': 'Missing user_id'}), 400
for key, item in items.items():
print('key": ', key)
print('item: ', item)
library_item_id = item['library_item_id']
result[library_item_id] = compute_score(user_id, item)
end = timer()
print(f'time to compute batch of {len(items)} items (in seconds): {end - start}')
app.logger.info(f'time to compute batch of {len(items)} items (in seconds): {timer() - start}')
return jsonify(result)
except Exception as e:
app.logger.error(f"exception in batch endpoint: {request.get_json()}\n{e}")